Add domain modules trained End-to-End (#149)
Allow for some domain modules to be trained end-to-end with the global
workspace.

This brings some breaking changes:
1. `DomainModule.compute_loss` and `DomainModule.compute_*_loss` now require a third
   parameter `raw_target: Any` that stores the raw domain input (before being encoded).
   This is useful for unimodal losses that require the actual inputs to compute the loss.
2. `GWLossesBase.step` requires a new first argument `raw_data: RawDomainGroupsT` to
   pass the `raw_targets` to the domain modules.

Change (1) needs to be applied in every project that implements a `DomainModule`
(i.e. every project); a migration sketch is shown below. Change (2) has less impact,
since most projects won't redefine a loss module.
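For change (1), migration mostly amounts to accepting the extra argument in the loss
methods. A minimal sketch, reusing the `GenericDomain` module from the docs with a plain
MSE loss (the loss choice is illustrative, not prescribed by this commit):

```python
from typing import Any

import torch
import torch.nn.functional as F

from shimmer import DomainModule, LossOutput


class GenericDomain(DomainModule):
    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        # `raw_target` carries the domain input before encoding; unimodal
        # losses that need the original data (e.g. raw pixels) can use it.
        return LossOutput(loss=F.mse_loss(pred, target))
```

To actually train a domain end-to-end, call `unfreeze()` on the module before building
the global workspace (modules stay frozen by default) and, if a unimodal loss term is
wanted, override `compute_domain_loss`; its output is added to the total loss during
training.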
bdvllrs authored Oct 4, 2024
1 parent 7993926 commit 4e011e7
Showing 9 changed files with 251 additions and 42 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -71,3 +71,13 @@ refers to `DeterministicGlobalWorkspace`.
[`RandomSelection`](https://ruflab.github.io/shimmer/latest/shimmer/modules/selection.html#RandomSelection)
mechanism. For the old behavior, use
[`GlobalWorkspace2Domains`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace2Domains).

# 0.6.0
* Allow for some domain modules to be trained end-to-end with the global workspace.
This brings some breaking changes:
1. `DomainModule.compute_loss` and `DomainModule.compute_*_loss` now require a third
parameter `raw_target: Any` that stores the raw domain input (before being encoded).
This is useful for unimodal losses that require the actual inputs to compute the loss.
2. `GWLossesBase.step` requires a new first argument `raw_data: RawDomainGroupsT` to
pass the `raw_targets` to the domain modules.

11 changes: 8 additions & 3 deletions docs/shimmer_basics.md
@@ -95,6 +95,8 @@ methods:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer

from shimmer import DomainModule

@@ -134,11 +136,11 @@ class GenericDomain(DomainModule):
self.log("val_loss", loss)
return loss

def configure_optimizers(self) -> torch.optim.Optimizer:
def configure_optimizers(self) -> Optimizer:
"""
Define which optimizer to use
"""
return torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6)
return AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6)
```

With all this defined, we can make a script to train our unimodal module:
@@ -385,6 +387,8 @@ We have previously defined `GenericDomain` so we can train the module. We now need
to add some mandatory methods that will be used by the GlobalWorkspace

```python
from typing import Any

from shimmer import LossOutput


@@ -412,12 +416,13 @@ class GenericDomain(DomainModule):
"""
return self.decoder(z)

def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> LossOutput:
"""
Computes a generic loss in the domain's latent representation.
This must return a LossOutput object. LossOutput is used to separate
the loss used for training the model (given to loss parameter), and
additional metrics that are logged, but not trained on.
The `raw_target` parameter contains the raw domain data, before it is encoded.
"""
return LossOutput(loss=F.mse_loss(pred, target))
```
13 changes: 10 additions & 3 deletions examples/main_example/domains.py
@@ -1,6 +1,10 @@
from typing import Any

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch.optim.adamw import AdamW
from torch.optim.optimizer import Optimizer

from shimmer import DomainModule, LossOutput

@@ -40,11 +44,11 @@ def validation_step(
self.log("val_loss", loss)
return loss

def configure_optimizers(self) -> torch.optim.Optimizer:
def configure_optimizers(self) -> Optimizer:
"""
Define which optimizer to use
"""
return torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6)
return AdamW(self.parameters(), lr=1e-3, weight_decay=1e-6)

# shimmer stuff to train the GW

@@ -66,11 +70,14 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:
"""
return self.decoder(z)

def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
def compute_loss(
self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput:
"""
Computes a generic loss in the domain's latent representation.
This must return a LossOutput object. LossOutput is used to separate
the loss used for training the model (given to loss parameter), and
additional metrics that are logged, but not trained on.
The `raw_target` parameter contains the raw domain data, before it is encoded.
"""
return LossOutput(loss=F.mse_loss(pred, target))
54 changes: 45 additions & 9 deletions shimmer/modules/domain.py
@@ -71,6 +71,24 @@ def __init__(
self.latent_dim = latent_dim
"""The latent dimension of the module."""

self.is_frozen: bool | None = None
""" Whether the module is frozen. If None, it is frozen by default. """

def freeze(self) -> None:
"""
Freezes the module. This is the default mode.
"""
self.is_frozen = True
return super().freeze()

def unfreeze(self) -> None:
"""
Unfreezes the module. This is useful to train the domain module end-to-end.
This also unlocks `compute_domain_loss` during training.
"""
self.is_frozen = False
return super().unfreeze()

def encode(self, x: Any) -> torch.Tensor:
"""
Encode the domain data into a unimodal representation.
@@ -94,14 +112,15 @@ def decode(self, z: torch.Tensor) -> Any:
raise NotImplementedError

def compute_loss(
self, pred: torch.Tensor, target: torch.Tensor
self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput | None:
"""
Generic loss computation for the modality.
Args:
pred (`torch.Tensor`): prediction of the model
target (`torch.Tensor`): target tensor
raw_target (`Any`): raw data from the input
Results:
`LossOutput | None`: LossOutput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
@@ -110,7 +129,7 @@ def compute_loss(
raise NotImplementedError

def compute_dcy_loss(
self, pred: torch.Tensor, target: torch.Tensor
self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput | None:
"""
Computes the loss for a demi-cycle. Override if the demi-cycle loss is
@@ -119,16 +138,17 @@ def compute_dcy_loss(
Args:
pred (`torch.Tensor`): prediction of the model
target (`torch.Tensor`): target tensor
raw_target (`Any`): raw data from the input
Results:
`LossOutput | None`: LossOutput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss; it can be used to deactivate
demi-cycle loss for this domain.
"""
return self.compute_loss(pred, target)
return self.compute_loss(pred, target, raw_target)

def compute_cy_loss(
self, pred: torch.Tensor, target: torch.Tensor
self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput | None:
"""
Computes the loss for a cycle. Override if the cycle loss is
@@ -137,16 +157,17 @@ def compute_cy_loss(
Args:
pred (`torch.Tensor`): prediction of the model
target (`torch.Tensor`): target tensor
raw_target (`Any`): raw data from the input
Results:
`LossOutput | None`: LossOutput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss; it can be used to deactivate
cycle loss for this domain.
"""
return self.compute_loss(pred, target)
return self.compute_loss(pred, target, raw_target)

def compute_tr_loss(
self, pred: torch.Tensor, target: torch.Tensor
self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput | None:
"""
Computes the loss for a translation. Override if the translation loss is
@@ -155,16 +176,17 @@ def compute_tr_loss(
Args:
pred (`torch.Tensor`): prediction of the model
target (`torch.Tensor`): target tensor
raw_target (`Any`): raw data from the input
Results:
`LossOutput | None`: LossOutput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss; it can be used to deactivate
translation loss for this domain.
"""
return self.compute_loss(pred, target)
return self.compute_loss(pred, target, raw_target)

def compute_fused_loss(
self, pred: torch.Tensor, target: torch.Tensor
self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput | None:
"""
Computes the loss for fused (fusion). Override if the fused loss is
@@ -173,10 +195,24 @@ def compute_fused_loss(
Args:
pred (`torch.Tensor`): prediction of the model
target (`torch.Tensor`): target tensor
raw_target (`Any`): raw data from the input
Results:
`LossOutput | None`: LossOutput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss; it can be used to deactivate
fused loss for this domain.
"""
return self.compute_loss(pred, target)
return self.compute_loss(pred, target, raw_target)

def compute_domain_loss(self, domain: Any) -> LossOutput | None:
"""
Compute the unimodal domain loss.
Args:
domain (`Any`): domain input
Results:
`LossOutput | None`: LossOutput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss.
"""
return None
39 changes: 36 additions & 3 deletions shimmer/modules/global_workspace.py
@@ -11,7 +11,7 @@
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR

from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
from shimmer.modules.domain import DomainModule
from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.gw_module import (
GWModule,
GWModuleBase,
@@ -484,6 +484,24 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup
for domains, latents in latents_domain.items()
}

def unimodal_losses(self, batch: RawDomainGroupsT) -> LossOutput | None:
metrics: dict[str, torch.Tensor] = {}
losses: list[torch.Tensor] = []
for group_domain_names, domain_group in batch.items():
if len(group_domain_names) > 1:
continue
for domain_name, domain in domain_group.items():
domain_mod = self.domain_mods[domain_name]
if not domain_mod.is_frozen:
loss = domain_mod.compute_domain_loss(domain)
if loss is not None:
for name, metric in loss.metrics.items():
metrics[f"{domain_name}/{name}"] = metric
losses.append(loss.loss)
if not len(losses):
return None
return LossOutput(loss=torch.stack(losses, dim=0).sum(), metrics=metrics)

def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT:
"""
The generic step used in `training_step`, `validation_step` and
@@ -499,7 +517,7 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT
domain_latents = self.encode_domains(batch)
batch_size = groups_batch_size(domain_latents)

loss_output = self.loss_mod.step(domain_latents, mode)
loss_output = self.loss_mod.step(batch, domain_latents, mode)

for name, metric in loss_output.all.items():
self.log(
@@ -509,6 +527,20 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT
add_dataloader_idx=False,
)

total_loss = loss_output.loss

unimodal_losses = self.unimodal_losses(batch)
if unimodal_losses is not None:
for name, metric in unimodal_losses.all.items():
self.log(
f"{mode}/domain_loss/{name}",
metric,
batch_size=batch_size,
add_dataloader_idx=False,
)

total_loss += unimodal_losses.loss

return total_loss

def validation_step( # type: ignore
@@ -604,7 +636,8 @@ def freeze_domain_modules(
"""

for mod in domain_mods.values():
mod.freeze()
if mod.is_frozen is None:
mod.freeze()
# Cast for better auto-completion at the expense of ModuleDict
return cast(dict[str, DomainModule], ModuleDict(domain_mods))
