From 4e011e79f307b9a4f9632a350952522d29deb92c Mon Sep 17 00:00:00 2001
From: bdvllrs
Date: Fri, 4 Oct 2024 15:53:53 +0200
Subject: [PATCH] Add domain modules trained End-to-End (#149)

Allow some domain modules to be trained end-to-end with the global workspace.
This brings some breaking changes:
1. `DomainModule.compute_loss` and `DomainModule.compute_*_loss` now require a
3rd parameter `raw_target: Any` that stores the raw domain input (before being
encoded). This is useful for unimodal losses that require the actual inputs to
compute the loss.
2. `GWLossesBase.step` requires a new first argument `raw_data: RawDomainGroupsT`
to pass the `raw_targets` to the domain modules.

Change 1) needs to be made in all projects that implement a `DomainModule`
(i.e., every project). Change 2) has less impact, as most projects won't
redefine a loss module.
---
 CHANGELOG.md                        | 10 ++++
 docs/shimmer_basics.md              | 11 +++--
 examples/main_example/domains.py    | 13 +++--
 shimmer/modules/domain.py           | 54 ++++++++++++++++----
 shimmer/modules/global_workspace.py | 39 +++++++++++++--
 shimmer/modules/losses.py           | 77 ++++++++++++++++++++---------
 tests/test_broadcast.py             |  8 ++-
 tests/test_freeze_domains.py        | 69 ++++++++++++++++++++++++++
 tests/utils.py                      | 12 +++++
 9 files changed, 251 insertions(+), 42 deletions(-)
 create mode 100644 tests/test_freeze_domains.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fc677e47..c9ad7f48 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -71,3 +71,13 @@ refers to `DeterministicGlobalWorkspace`.
   [`RandomSelection`](https://ruflab.github.io/shimmer/latest/shimmer/modules/selection.html#RandomSelection)
   mechanism. For the old behavior, use
   [`GlobalWorkspace2Domains`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace2Domains).
+
+# 0.6.0
+* Allow some domain modules to be trained end-to-end with the global workspace.
+  This brings some breaking changes:
+  1. `DomainModule.compute_loss` and `DomainModule.compute_*_loss` now require a 3rd
+     parameter `raw_target: Any` that stores the raw domain input (before being encoded).
+     This is useful for unimodal losses that require the actual inputs to compute the loss.
+  2. `GWLossesBase.step` requires a new first argument `raw_data: RawDomainGroupsT` to
+     pass the `raw_targets` to the domain modules.
+
diff --git a/docs/shimmer_basics.md b/docs/shimmer_basics.md
index aaced901..afcd772f 100644
--- a/docs/shimmer_basics.md
+++ b/docs/shimmer_basics.md
@@ -95,6 +95,8 @@ methods:
 import torch
 import torch.nn.functional as F
 from torch.nn import Linear
+from torch.optim.adamw import AdamW
+from torch.optim.optimizer import Optimizer
 
 from shimmer import DomainModule
 
@@ -134,11 +136,11 @@ class GenericDomain(DomainModule):
         self.log("val_loss", loss)
         return loss
 
-    def configure_optimizers(self) -> torch.optim.Optimizer:
+    def configure_optimizers(self) -> Optimizer:
         """
         Define which optimizer to use
         """
-        return torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6)
+        return AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6)
 ```
 
 With all this defined, we can make a script to train our unimodal module:
@@ -385,6 +387,8 @@ We have previously defined `GenericDomain` so we can train the module.
 We now need to add some mandatory methods that will be used by the GlobalWorkspace
 
 ```python
+from typing import Any
+
 from shimmer import LossOutput
 
 
@@ -412,12 +416,13 @@ class GenericDomain(DomainModule):
         """
         return self.decoder(z)
 
-    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> LossOutput:
         """
         Computes a generic loss in the domain's latent representation.
         This must return a LossOutput object. LossOutput is used to separate the loss
         used for training the model (given to loss parameter), and additional metrics
         that are logged, but not trained on.
+        The `raw_target` parameter contains the raw domain data (before encoding).
         """
         return LossOutput(loss=F.mse_loss(pred, target))
 ```
diff --git a/examples/main_example/domains.py b/examples/main_example/domains.py
index 40d0ac13..bb8611f1 100644
--- a/examples/main_example/domains.py
+++ b/examples/main_example/domains.py
@@ -1,6 +1,10 @@
+from typing import Any
+
 import torch
 import torch.nn.functional as F
 from torch.nn import Linear
+from torch.optim.adamw import AdamW
+from torch.optim.optimizer import Optimizer
 
 from shimmer import DomainModule, LossOutput
 
@@ -40,11 +44,11 @@ def validation_step(
         self.log("val_loss", loss)
         return loss
 
-    def configure_optimizers(self) -> torch.optim.Optimizer:
+    def configure_optimizers(self) -> Optimizer:
         """
         Define which optimizer to use
         """
-        return torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6)
+        return AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6)
 
 
 # shimmer stuff to train the GW
@@ -66,11 +70,14 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:
         """
         return self.decoder(z)
 
-    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+    def compute_loss(
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
+    ) -> LossOutput:
         """
         Computes a generic loss in the domain's latent representation.
         This must return a LossOutput object. LossOutput is used to separate the loss
         used for training the model (given to loss parameter), and additional metrics
         that are logged, but not trained on.
+        The `raw_target` parameter contains the raw domain data (before encoding).
         """
         return LossOutput(loss=F.mse_loss(pred, target))
diff --git a/shimmer/modules/domain.py b/shimmer/modules/domain.py
index fda7dbc2..c909f010 100644
--- a/shimmer/modules/domain.py
+++ b/shimmer/modules/domain.py
@@ -71,6 +71,24 @@ def __init__(
         self.latent_dim = latent_dim
         """The latent dimension of the module."""
 
+        self.is_frozen: bool | None = None
+        """Whether the module is frozen. If `None`, it is frozen by default."""
+
+    def freeze(self) -> None:
+        """
+        Freezes the module. This is the default mode.
+        """
+        self.is_frozen = True
+        return super().freeze()
+
+    def unfreeze(self) -> None:
+        """
+        Unfreezes the module. This is useful to train the domain module end-to-end.
+        This also enables `compute_domain_loss` during training.
+        """
+        self.is_frozen = False
+        return super().unfreeze()
+
     def encode(self, x: Any) -> torch.Tensor:
         """
         Encode the domain data into a unimodal representation.
@@ -94,7 +112,7 @@ def decode(self, z: torch.Tensor) -> Any:
         raise NotImplementedError
 
     def compute_loss(
-        self, pred: torch.Tensor, target: torch.Tensor
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
     ) -> LossOutput | None:
         """
         Generic loss computation for the modality.
@@ -102,6 +120,7 @@ def compute_loss(
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw input data (before encoding)
         Returns:
             `LossOutput | None`: LossOutput with training loss and additional metrics.
                 If `None` is returned, this loss will be ignored and will not
@@ -110,7 +129,7 @@ def compute_loss(
         raise NotImplementedError
 
     def compute_dcy_loss(
-        self, pred: torch.Tensor, target: torch.Tensor
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
     ) -> LossOutput | None:
         """
         Computes the loss for a demi-cycle. Override if the demi-cycle loss is
@@ -119,16 +138,17 @@ def compute_dcy_loss(
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw input data (before encoding)
         Returns:
             `LossOutput | None`: LossOutput with training loss and additional metrics.
                 If `None` is returned, this loss will be ignored and will not
                 participate in the total loss; it can be used to deactivate
                 demi-cycle loss for this domain.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)
 
     def compute_cy_loss(
-        self, pred: torch.Tensor, target: torch.Tensor
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
     ) -> LossOutput | None:
         """
         Computes the loss for a cycle. Override if the cycle loss is
@@ -137,16 +157,17 @@ def compute_cy_loss(
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw input data (before encoding)
         Returns:
             `LossOutput | None`: LossOutput with training loss and additional metrics.
                 If `None` is returned, this loss will be ignored and will not
                 participate in the total loss; it can be used to deactivate
                 cycle loss for this domain.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)
 
     def compute_tr_loss(
-        self, pred: torch.Tensor, target: torch.Tensor
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
     ) -> LossOutput | None:
         """
         Computes the loss for a translation. Override if the translation loss is
@@ -155,16 +176,17 @@ def compute_tr_loss(
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw input data (before encoding)
         Returns:
             `LossOutput | None`: LossOutput with training loss and additional metrics.
                 If `None` is returned, this loss will be ignored and will not
                 participate in the total loss; it can be used to deactivate
                 translation loss for this domain.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)
 
     def compute_fused_loss(
-        self, pred: torch.Tensor, target: torch.Tensor
+        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
     ) -> LossOutput | None:
         """
         Computes the loss for fusion. Override if the fused loss is
@@ -173,10 +195,24 @@ def compute_fused_loss(
         Args:
             pred (`torch.Tensor`): prediction of the model
             target (`torch.Tensor`): target tensor
+            raw_target (`Any`): raw input data (before encoding)
         Returns:
             `LossOutput | None`: LossOutput with training loss and additional metrics.
                 If `None` is returned, this loss will be ignored and will not
                 participate in the total loss; it can be used to deactivate
                 fused loss for this domain.
         """
-        return self.compute_loss(pred, target)
+        return self.compute_loss(pred, target, raw_target)
+
+    def compute_domain_loss(self, domain: Any) -> LossOutput | None:
+        """
+        Compute the unimodal domain loss.
+
+        Args:
+            domain (`Any`): domain input
+        Returns:
+            `LossOutput | None`: LossOutput with training loss and additional
+                metrics.
+            If `None` is returned, this loss will be ignored and will not
+                participate in the total loss.
+        """
+        return None
diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index c0df2fa3..ee03690a 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -11,7 +11,7 @@
 from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
 
 from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
-from shimmer.modules.domain import DomainModule
+from shimmer.modules.domain import DomainModule, LossOutput
 from shimmer.modules.gw_module import (
     GWModule,
     GWModuleBase,
@@ -484,6 +484,24 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup
         for domains, latents in latents_domain.items()
     }
 
+    def unimodal_losses(self, batch: RawDomainGroupsT) -> LossOutput | None:
+        metrics: dict[str, torch.Tensor] = {}
+        losses: list[torch.Tensor] = []
+        for group_domain_names, domain_group in batch.items():
+            if len(group_domain_names) > 1:
+                continue
+            for domain_name, domain in domain_group.items():
+                domain_mod = self.domain_mods[domain_name]
+                if not domain_mod.is_frozen:
+                    loss = domain_mod.compute_domain_loss(domain)
+                    if loss is not None:
+                        for name, metric in loss.metrics.items():
+                            metrics[f"{domain_name}/{name}"] = metric
+                        losses.append(loss.loss)
+        if not len(losses):
+            return None
+        return LossOutput(loss=torch.stack(losses, dim=0).sum(), metrics=metrics)
+
     def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT:
         """
         The generic step used in `training_step`, `validation_step` and
@@ -499,7 +517,7 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT
         domain_latents = self.encode_domains(batch)
         batch_size = groups_batch_size(domain_latents)
 
-        loss_output = self.loss_mod.step(domain_latents, mode)
+        loss_output = self.loss_mod.step(batch, domain_latents, mode)
 
         for name, metric in loss_output.all.items():
             self.log(
@@ -509,6 +527,20 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT
             add_dataloader_idx=False,
         )
 
+        total_loss = loss_output.loss
+
+        unimodal_losses = self.unimodal_losses(batch)
+        if unimodal_losses is not None:
+            for name, metric in unimodal_losses.all.items():
+                self.log(
+                    f"{mode}/domain_loss/{name}",
+                    metric,
+                    batch_size=batch_size,
+                    add_dataloader_idx=False,
+                )
+
+            total_loss += unimodal_losses.loss
+
-        return loss_output.loss
+        return total_loss
 
     def validation_step(  # type: ignore
@@ -604,7 +636,8 @@ def freeze_domain_modules(
     """
 
     for mod in domain_mods.values():
-        mod.freeze()
+        if mod.is_frozen is None:
+            mod.freeze()
 
     # Cast for better auto-completion at the expense of ModuleDict
     return cast(dict[str, DomainModule], ModuleDict(domain_mods))
diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py
index 49b60924..6ff9c116 100644
--- a/shimmer/modules/losses.py
+++ b/shimmer/modules/losses.py
@@ -13,7 +13,7 @@
     GWModuleBayesian,
 )
 from shimmer.modules.selection import SelectionBase
-from shimmer.types import LatentsDomainGroupsT, ModelModeT
+from shimmer.types import LatentsDomainGroupsT, ModelModeT, RawDomainGroupsT
 
 
 class GWLossesBase(torch.nn.Module, ABC):
@@ -26,6 +26,7 @@ class GWLossesBase(torch.nn.Module, ABC):
     @abstractmethod
     def step(
         self,
+        raw_data: RawDomainGroupsT,
         domain_latents: LatentsDomainGroupsT,
         mode: ModelModeT,
     ) -> LossOutput:
@@ -33,6 +34,7 @@
def step( Computes the losses. Args: + raw_data (`RawDomainGroupsT`): raw input data domain_latents (`LatentsDomainGroupsT`): All latent groups mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): model mode Returns: @@ -46,6 +48,7 @@ def demi_cycle_loss( selection_mod: SelectionBase, domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, + raw_data: RawDomainGroupsT, ) -> dict[str, torch.Tensor]: """ Computes the demi-cycle loss. @@ -62,6 +65,7 @@ def demi_cycle_loss( domain_mods (`Mapping[str, DomainModule]`): the domain modules latent_domains (`shimmer.types.LatentsDomainGroupsT`): the latent unimodal groups + raw_data (`RawDomainGroupsT`): raw input data Returns: `dict[str, torch.Tensor]`: a dict of metrics. @@ -77,7 +81,9 @@ def demi_cycle_loss( x_recons = gw_mod.decode( gw_mod.encode_and_fuse(latents, selection_mod), domains={domain_name} )[domain_name] - loss_output = domain_mod.compute_dcy_loss(x_recons, latents[domain_name]) + loss_output = domain_mod.compute_dcy_loss( + x_recons, latents[domain_name], raw_data[domains][domain_name] + ) if loss_output is None: continue losses[f"demi_cycle_{domain_name}"] = loss_output.loss @@ -94,6 +100,7 @@ def cycle_loss( selection_mod: SelectionBase, domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, + raw_data: RawDomainGroupsT, ) -> dict[str, torch.Tensor]: """ Computes the cycle loss. @@ -111,6 +118,7 @@ def cycle_loss( selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use domain_mods (`Mapping[str, DomainModule]`): the domain modules latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups + raw_data (`RawDomainGroupsT`): raw input data Returns: `dict[str, torch.Tensor]`: a dict of metrics. @@ -139,6 +147,7 @@ def cycle_loss( loss_output = domain_mod.compute_cy_loss( x_recons[domain_name_source], latents_source[domain_name_source], + raw_data[domains_source][domain_name_source], ) if loss_output is None: continue @@ -157,6 +166,7 @@ def translation_loss( selection_mod: SelectionBase, domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, + raw_data: RawDomainGroupsT, ) -> dict[str, torch.Tensor]: """ Computes the translation loss. @@ -174,6 +184,7 @@ def translation_loss( gw_mod (`GWModuleBase`): The GWModule to use domain_mods (`Mapping[str, DomainModule]`): the domain modules latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups + raw_data (`RawDomainGroupsT`): raw input data Returns: `dict[str, torch.Tensor]`: a dict of metrics. @@ -204,6 +215,7 @@ def translation_loss( loss_output = mod.compute_tr_loss( prediction, latents[domain_name_target], + raw_data[domains][domain_name_target], ) if loss_output is None: continue @@ -396,7 +408,7 @@ def __init__( self.contrastive_fn = contrastive_fn def demi_cycle_loss( - self, latent_domains: LatentsDomainGroupsT + self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT ) -> dict[str, torch.Tensor]: """ Computes the demi-cycle loss. @@ -405,16 +417,17 @@ def demi_cycle_loss( Args: latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups + raw_data (`RawDomainGroupsT`): raw input data Returns: `dict[str, torch.Tensor]`: a dict of metrics. 
""" return demi_cycle_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains + self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data ) def cycle_loss( - self, latent_domains: LatentsDomainGroupsT + self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT ) -> dict[str, torch.Tensor]: """ Computes the cycle loss. @@ -423,16 +436,17 @@ def cycle_loss( Args: latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups + raw_data (`RawDomainGroupsT`): raw input data Returns: `dict[str, torch.Tensor]`: a dict of metrics. """ return cycle_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains + self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data ) def translation_loss( - self, latent_domains: LatentsDomainGroupsT + self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT ) -> dict[str, torch.Tensor]: """ Computes the translation loss. @@ -441,12 +455,13 @@ def translation_loss( Args: latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups + raw_data (`RawDomainGroupsT`): raw input data Returns: `dict[str, torch.Tensor]`: a dict of metrics. """ return translation_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains + self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data ) def contrastive_loss( @@ -466,7 +481,10 @@ def contrastive_loss( return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) def step( - self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT + self, + raw_data: RawDomainGroupsT, + domain_latents: LatentsDomainGroupsT, + mode: ModelModeT, ) -> LossOutput: """ Computes and returns the losses @@ -478,6 +496,7 @@ def step( - Contrastive metrics (see `GWLosses.contrastive_loss`) Args: + raw_data (`RawDomainGroupsT`): raw input data domain_latents (`LatentsDomainGroupsT`): All latent groups mode (`ModelModeT`): model mode Returns: @@ -485,9 +504,9 @@ def step( """ metrics: dict[str, torch.Tensor] = {} - metrics.update(self.demi_cycle_loss(domain_latents)) - metrics.update(self.cycle_loss(domain_latents)) - metrics.update(self.translation_loss(domain_latents)) + metrics.update(self.demi_cycle_loss(domain_latents, raw_data)) + metrics.update(self.cycle_loss(domain_latents, raw_data)) + metrics.update(self.translation_loss(domain_latents, raw_data)) metrics.update(self.contrastive_loss(domain_latents)) loss = torch.stack( @@ -524,6 +543,7 @@ def broadcast_loss( selection_mod: SelectionBase, domain_mods: Mapping[str, DomainModule], latent_domains: LatentsDomainGroupsT, + raw_data: RawDomainGroupsT, ) -> dict[str, torch.Tensor]: """ Computes broadcast loss including demi-cycle, cycle, and translation losses. @@ -533,6 +553,7 @@ def broadcast_loss( selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use domain_mods (`Mapping[str, DomainModule]`): the domain modules latent_domains: The latent domain representations. + raw_data (`RawDomainGroupsT`): raw input data Returns: A dictionary with the total loss and additional metrics. 
@@ -581,7 +602,9 @@ def broadcast_loss( else: loss_fn = domain_mods[domain].compute_fused_loss - loss_output = loss_fn(pred, ground_truth) + loss_output = loss_fn( + pred, ground_truth, raw_data[group_domains][domain] + ) if loss_output is None: continue @@ -621,7 +644,9 @@ def broadcast_loss( for domain in selected_latents: re_ground_truth = latents[domain] re_loss_output = domain_mods[domain].compute_cy_loss( - re_decoded_latents[domain], re_ground_truth + re_decoded_latents[domain], + re_ground_truth, + raw_data[group_domains][domain], ) if re_loss_output is None: continue @@ -731,19 +756,23 @@ def contrastive_loss( return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) def broadcast_loss( - self, latent_domains: LatentsDomainGroupsT + self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT ) -> dict[str, torch.Tensor]: return broadcast_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains + self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data ) def step( - self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT + self, + raw_data: RawDomainGroupsT, + domain_latents: LatentsDomainGroupsT, + mode: ModelModeT, ) -> LossOutput: """ Performs a step of loss computation. Args: + raw_data (`RawDomainGroupsT`): raw input data domain_latents: Latent representations for all domains. mode: The mode in which the model is currently operating. @@ -754,7 +783,7 @@ def step( metrics: dict[str, torch.Tensor] = {} metrics.update(self.contrastive_loss(domain_latents)) - metrics.update(self.broadcast_loss(domain_latents)) + metrics.update(self.broadcast_loss(domain_latents, raw_data)) loss = torch.stack( [ @@ -845,19 +874,23 @@ def contrastive_loss( return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) def broadcast_loss( - self, latent_domains: LatentsDomainGroupsT + self, latent_domains: LatentsDomainGroupsT, raw_data: RawDomainGroupsT ) -> dict[str, torch.Tensor]: return broadcast_loss( - self.gw_mod, self.selection_mod, self.domain_mods, latent_domains + self.gw_mod, self.selection_mod, self.domain_mods, latent_domains, raw_data ) def step( - self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT + self, + raw_data: RawDomainGroupsT, + domain_latents: LatentsDomainGroupsT, + mode: ModelModeT, ) -> LossOutput: """ Performs a step of loss computation. Args: + raw_data (`RawDomainGroupsT`): raw input data domain_latents: Latent representations for all domains. mode: The mode in which the model is currently operating. 
@@ -868,7 +901,7 @@ def step( metrics: dict[str, torch.Tensor] = {} metrics.update(self.contrastive_loss(domain_latents)) - metrics.update(self.broadcast_loss(domain_latents)) + metrics.update(self.broadcast_loss(domain_latents, raw_data)) loss = torch.stack( [ diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index a3c8fbf1..bde29e8d 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from torch import nn @@ -18,7 +20,9 @@ def encode(self, x: torch.Tensor) -> torch.Tensor: def decode(self, z: torch.Tensor) -> torch.Tensor: return self.decoder(z) # Simple forward pass through decoder - def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + def compute_loss( + self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any + ) -> LossOutput: loss = torch.mean((pred - target) ** 2) # Simple MSE loss return LossOutput(loss=loss) # Constructing LossOutput with the loss @@ -62,7 +66,7 @@ def test_broadcast_loss(): } # Test broadcast_loss with the corrected structure - output = gw_fusion.loss_mod.broadcast_loss(latent_domains) + output = gw_fusion.loss_mod.broadcast_loss(latent_domains, latent_domains) er_msg = "Demi-cycle, cycle, fused and translation metrics should be in the output." assert all( diff --git a/tests/test_freeze_domains.py b/tests/test_freeze_domains.py new file mode 100644 index 00000000..9e94549c --- /dev/null +++ b/tests/test_freeze_domains.py @@ -0,0 +1,69 @@ +from utils import DummyDomainModuleWithParams + +from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder + + +def test_training(): + domains = { + "v": DummyDomainModuleWithParams(latent_dim=128), + "t": DummyDomainModuleWithParams(latent_dim=128), + "a": DummyDomainModuleWithParams(latent_dim=128), + } + + domains["a"].unfreeze() + + workspace_dim = 16 + + gw_encoders = { + "v": GWEncoder( + domains["v"].latent_dim, + hidden_dim=64, + out_dim=workspace_dim, + n_layers=1, + ), + "t": GWEncoder( + domains["t"].latent_dim, + hidden_dim=64, + out_dim=workspace_dim, + n_layers=1, + ), + "a": GWEncoder( + domains["a"].latent_dim, + hidden_dim=64, + out_dim=workspace_dim, + n_layers=1, + ), + } + + gw_decoders = { + "v": GWDecoder( + workspace_dim, + hidden_dim=64, + out_dim=domains["v"].latent_dim, + n_layers=1, + ), + "t": GWDecoder( + workspace_dim, + hidden_dim=64, + out_dim=domains["t"].latent_dim, + n_layers=1, + ), + "a": GWDecoder( + workspace_dim, + hidden_dim=64, + out_dim=domains["a"].latent_dim, + n_layers=1, + ), + } + + gw = GlobalWorkspace2Domains( + domains, + gw_encoders, + gw_decoders, + workspace_dim=16, + loss_coefs={}, + ) + assert gw.domain_mods["v"].is_frozen + assert not gw.domain_mods["a"].is_frozen + assert not len([p for p in gw.domain_mods["v"].parameters() if p.requires_grad]) + assert len([p for p in gw.domain_mods["a"].parameters() if p.requires_grad]) diff --git a/tests/utils.py b/tests/utils.py index 8ff832ce..9074ea84 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -36,3 +36,15 @@ def encode(self, x: DummyData) -> torch.Tensor: def decode(self, z: torch.Tensor) -> DummyData: return DummyData(vec=z) + + +class DummyDomainModuleWithParams(DomainModule): + def __init__(self, latent_dim: int) -> None: + super().__init__(latent_dim) + self.net = torch.nn.Linear(1, 1) + + def encode(self, x: DummyData) -> torch.Tensor: + return x.vec + + def decode(self, z: torch.Tensor) -> DummyData: + return DummyData(vec=z)
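For projects that define their own `DomainModule`, breaking change 1) is mechanical: add
the `raw_target` parameter, and optionally opt into end-to-end training. A minimal sketch
(the `MyDomain` class, its dimensions, and its reconstruction loss are illustrative
assumptions, not part of this patch):

```python
from typing import Any

import torch
import torch.nn.functional as F

from shimmer import DomainModule, LossOutput


class MyDomain(DomainModule):
    def __init__(self, input_dim: int, latent_dim: int) -> None:
        super().__init__(latent_dim)
        self.encoder = torch.nn.Linear(input_dim, latent_dim)
        self.decoder = torch.nn.Linear(latent_dim, input_dim)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)

    # Breaking change 1): the old signature was (self, pred, target).
    # `raw_target` receives the domain input before encoding; it can be
    # ignored when the loss only needs the latent vectors.
    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        return LossOutput(loss=F.mse_loss(pred, target))

    # Only used when the module is unfrozen; the base implementation
    # returns None, which skips the unimodal loss entirely.
    def compute_domain_loss(self, domain: Any) -> LossOutput | None:
        reconstruction = self.decode(self.encode(domain))
        return LossOutput(loss=F.mse_loss(reconstruction, domain))


domain = MyDomain(input_dim=8, latent_dim=32)
domain.unfreeze()  # opt into end-to-end training with the global workspace
```

Custom loss modules (breaking change 2) only need the new first argument:
`step(raw_data, domain_latents, mode)` instead of `step(domain_latents, mode)`.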