From 762a17e42455d8538e82fca501481966f31c8608 Mon Sep 17 00:00:00 2001
From: bdvllrs <bdvllrs@gmail.com>
Date: Fri, 8 Mar 2024 10:44:11 +0000
Subject: [PATCH] Update types and docstring of losses

---
 shimmer/modules/losses.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py
index 7100f602..154843f8 100644
--- a/shimmer/modules/losses.py
+++ b/shimmer/modules/losses.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Mapping
 from typing import TypedDict
 
 import torch
@@ -40,7 +41,7 @@ def step(
 
 def demi_cycle_loss(
     gw_mod: GWModuleBase,
-    domain_mods: dict[str, DomainModule],
+    domain_mods: Mapping[str, DomainModule],
     latent_domains: LatentsDomainGroupsT,
 ) -> dict[str, torch.Tensor]:
     """Computes the demi-cycle loss.
@@ -53,7 +54,7 @@ def demi_cycle_loss(
 
     Args:
         gw_mod (`GWModuleBase`): The GWModule to use
-        domain_mods (`dict[str, DomainModule]`): the domain modules
+        domain_mods (`Mapping[str, DomainModule]`): the domain modules
         latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
 
     Returns:
@@ -82,7 +83,7 @@ def demi_cycle_loss(
 
 def cycle_loss(
     gw_mod: GWModuleBase,
-    domain_mods: dict[str, DomainModule],
+    domain_mods: Mapping[str, DomainModule],
     latent_domains: LatentsDomainGroupsT,
 ) -> dict[str, torch.Tensor]:
     """Computes the cycle loss.
@@ -97,7 +98,7 @@ def cycle_loss(
 
     Args:
         gw_mod (`GWModuleBase`): The GWModule to use
-        domain_mods (`dict[str, DomainModule]`): the domain modules
+        domain_mods (`Mapping[str, DomainModule]`): the domain modules
         latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
 
     Returns:
@@ -137,7 +138,7 @@ def cycle_loss(
 
 def translation_loss(
     gw_mod: GWModuleBase,
-    domain_mods: dict[str, DomainModule],
+    domain_mods: Mapping[str, DomainModule],
     latent_domains: LatentsDomainGroupsT,
 ) -> dict[str, torch.Tensor]:
     """Computes the translation loss.
@@ -153,7 +154,7 @@ def translation_loss(
 
     Args:
         gw_mod (`GWModuleBase`): The GWModule to use
-        domain_mods (`dict[str, DomainModule]`): the domain modules
+        domain_mods (`Mapping[str, DomainModule]`): the domain modules
         latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
 
     Returns:
@@ -374,7 +375,7 @@ def demi_cycle_loss(
     ) -> dict[str, torch.Tensor]:
         """Computes the demi-cycle loss.
 
-        See `shimmer.mdoules.losses.demi_cycle_loss`.
+        See `shimmer.modules.losses.demi_cycle_loss`.
 
         Args:
             latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
@@ -389,7 +390,7 @@ def cycle_loss(
     ) -> dict[str, torch.Tensor]:
         """Computes the cycle loss.
 
-        See `cycle_loss`.
+        See `shimmer.modules.losses.cycle_loss`.
 
         Args:
             latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
@@ -404,7 +405,7 @@ def translation_loss(
     ) -> dict[str, torch.Tensor]:
         """Computes the translation loss.
 
-        See `shimmer.mdoules.losses.translation_loss`.
+        See `shimmer.modules.losses.translation_loss`.
 
         Args:
             latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
@@ -419,7 +420,7 @@ def contrastive_loss(
     ) -> dict[str, torch.Tensor]:
         """Computes the contrastive loss.
 
-        See `shimmer.mdoules.losses.contrastive_loss`.
+        See `shimmer.modules.losses.contrastive_loss`.
 
         Args:
             latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups