From 762a17e42455d8538e82fca501481966f31c8608 Mon Sep 17 00:00:00 2001 From: bdvllrs <bdvllrs@gmail.com> Date: Fri, 8 Mar 2024 10:44:11 +0000 Subject: [PATCH] Update types and docstring of losses --- shimmer/modules/losses.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 7100f602..154843f8 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc import Mapping from typing import TypedDict import torch @@ -40,7 +41,7 @@ def step( def demi_cycle_loss( gw_mod: GWModuleBase, - domain_mods: dict[str, DomainModule], + domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, ) -> dict[str, torch.Tensor]: """Computes the demi-cycle loss. @@ -53,7 +54,7 @@ def demi_cycle_loss( Args: gw_mod (`GWModuleBase`): The GWModule to use - domain_mods (`dict[str, DomainModule]`): the domain modules + domain_mods (`Mapping[str, DomainModule]`): the domain modules latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups Returns: @@ -82,7 +83,7 @@ def demi_cycle_loss( def cycle_loss( gw_mod: GWModuleBase, - domain_mods: dict[str, DomainModule], + domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, ) -> dict[str, torch.Tensor]: """Computes the cycle loss. @@ -97,7 +98,7 @@ def cycle_loss( Args: gw_mod (`GWModuleBase`): The GWModule to use - domain_mods (`dict[str, DomainModule]`): the domain modules + domain_mods (`Mapping[str, DomainModule]`): the domain modules latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups Returns: @@ -137,7 +138,7 @@ def cycle_loss( def translation_loss( gw_mod: GWModuleBase, - domain_mods: dict[str, DomainModule], + domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, ) -> dict[str, torch.Tensor]: """Computes the translation loss. @@ -153,7 +154,7 @@ def translation_loss( Args: gw_mod (`GWModuleBase`): The GWModule to use - domain_mods (`dict[str, DomainModule]`): the domain modules + domain_mods (`Mapping[str, DomainModule]`): the domain modules latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups Returns: @@ -374,7 +375,7 @@ def demi_cycle_loss( ) -> dict[str, torch.Tensor]: """Computes the demi-cycle loss. - See `shimmer.mdoules.losses.demi_cycle_loss`. + See `shimmer.modules.losses.demi_cycle_loss`. Args: latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups @@ -389,7 +390,7 @@ def cycle_loss( ) -> dict[str, torch.Tensor]: """Computes the cycle loss. - See `cycle_loss`. + See `shimmer.modules.losses.cycle_loss`. Args: latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups @@ -404,7 +405,7 @@ def translation_loss( ) -> dict[str, torch.Tensor]: """Computes the translation loss. - See `shimmer.mdoules.losses.translation_loss`. + See `shimmer.modules.losses.translation_loss`. Args: latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups @@ -419,7 +420,7 @@ def contrastive_loss( ) -> dict[str, torch.Tensor]: """Computes the contrastive loss. - See `shimmer.mdoules.losses.contrastive_loss`. + See `shimmer.modules.losses.contrastive_loss`. Args: latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups