Skip to content

Commit

Permalink
Update types and docstring of losses
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Mar 8, 2024
1 parent e636ef4 commit 762a17e
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import TypedDict

import torch
Expand Down Expand Up @@ -40,7 +41,7 @@ def step(

def demi_cycle_loss(
gw_mod: GWModuleBase,
domain_mods: dict[str, DomainModule],
domain_mods: Mapping[str, DomainModule],
latent_domains: LatentsDomainGroupsT,
) -> dict[str, torch.Tensor]:
"""Computes the demi-cycle loss.
Expand All @@ -53,7 +54,7 @@ def demi_cycle_loss(
Args:
gw_mod (`GWModuleBase`): The GWModule to use
domain_mods (`dict[str, DomainModule]`): the domain modules
domain_mods (`Mapping[str, DomainModule]`): the domain modules
latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
Returns:
Expand Down Expand Up @@ -82,7 +83,7 @@ def demi_cycle_loss(

def cycle_loss(
gw_mod: GWModuleBase,
domain_mods: dict[str, DomainModule],
domain_mods: Mapping[str, DomainModule],
latent_domains: LatentsDomainGroupsT,
) -> dict[str, torch.Tensor]:
"""Computes the cycle loss.
Expand All @@ -97,7 +98,7 @@ def cycle_loss(
Args:
gw_mod (`GWModuleBase`): The GWModule to use
domain_mods (`dict[str, DomainModule]`): the domain modules
domain_mods (`Mapping[str, DomainModule]`): the domain modules
latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
Returns:
Expand Down Expand Up @@ -137,7 +138,7 @@ def cycle_loss(

def translation_loss(
gw_mod: GWModuleBase,
domain_mods: dict[str, DomainModule],
domain_mods: Mapping[str, DomainModule],
latent_domains: LatentsDomainGroupsT,
) -> dict[str, torch.Tensor]:
"""Computes the translation loss.
Expand All @@ -153,7 +154,7 @@ def translation_loss(
Args:
gw_mod (`GWModuleBase`): The GWModule to use
domain_mods (`dict[str, DomainModule]`): the domain modules
domain_mods (`Mapping[str, DomainModule]`): the domain modules
latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
Returns:
Expand Down Expand Up @@ -374,7 +375,7 @@ def demi_cycle_loss(
) -> dict[str, torch.Tensor]:
"""Computes the demi-cycle loss.
See `shimmer.mdoules.losses.demi_cycle_loss`.
See `shimmer.modules.losses.demi_cycle_loss`.
Args:
latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
Expand All @@ -389,7 +390,7 @@ def cycle_loss(
) -> dict[str, torch.Tensor]:
"""Computes the cycle loss.
See `cycle_loss`.
See `shimmer.modules.losses.cycle_loss`.
Args:
latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
Expand All @@ -404,7 +405,7 @@ def translation_loss(
) -> dict[str, torch.Tensor]:
"""Computes the translation loss.
See `shimmer.mdoules.losses.translation_loss`.
See `shimmer.modules.losses.translation_loss`.
Args:
latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
Expand All @@ -419,7 +420,7 @@ def contrastive_loss(
) -> dict[str, torch.Tensor]:
"""Computes the contrastive loss.
See `shimmer.mdoules.losses.contrastive_loss`.
See `shimmer.modules.losses.contrastive_loss`.
Args:
latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
Expand Down

0 comments on commit 762a17e

Please sign in to comment.