Skip to content

Commit

Permalink
Rename GlobalWorkspaceFusion to GlobalWorkspace (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored May 21, 2024
1 parent 85d79a1 commit d2b67c0
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 33 deletions.
12 changes: 6 additions & 6 deletions docs/shimmer_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ to make a GW in shimmer:
![architecture](assets/shimmer_architecture.png)

Let's detail:
- [`DomainModule`](https://bdvllrs.github.io/shimmer/shimmer/modules/domain.html#DomainModule)s
- [`DomainModule`](https://bdvllrs.github.io/shimmer/latest/shimmer/modules/domain.html#DomainModule)s
are the individual domain modules which encode domain data into a latent vector;
- the `GWModule` has access to the domain modules, and defines how to encode, decode and merge representations of the domains into a unique GW representation.
- finally `GlobalWorkspaceBase` takes all building blocks to make a [Pytorch Lightning](https://lightning.ai/docs/pytorch/stable/) module
Expand Down Expand Up @@ -87,7 +87,7 @@ class DomainDataModule(LightningDataModule):
Now that our data module is defined, let's create `DomainModule`s.

## `DomainModule`
For more details about DomainModules, see the [DomainModule API docs](https://bdvllrs.github.io/shimmer/shimmer/modules/domain.html#DomainModule).
For more details about DomainModules, see the [DomainModule API docs](https://bdvllrs.github.io/shimmer/latest/shimmer/modules/domain.html#DomainModule).
The `DomainModule` class extends from a LightningModule and requires you to define some
methods:

Expand Down Expand Up @@ -422,7 +422,7 @@ class GenericDomain(DomainModule):
return LossOutput(loss=F.mse_loss(pred, target))
```

To learn more about LossOutput, see [API docs](https://bdvllrs.github.io/shimmer/shimmer/modules/domain.html#LossOutput).
To learn more about LossOutput, see [API docs](https://bdvllrs.github.io/shimmer/latest/shimmer/modules/domain.html#LossOutput).

## Let's make a GW!

Expand All @@ -436,7 +436,7 @@ from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch import nn

from shimmer import GlobalWorkspace, GWDecoder, GWEncoder, LossCoefs
from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder, LossCoefs
from shimmer.modules.global_workspace import SchedulerArgs


Expand Down Expand Up @@ -508,7 +508,7 @@ def train_gw():

n_epochs = 4

global_workspace = GlobalWorkspace(
global_workspace = GlobalWorkspace2Domains(
domain_mods,
gw_encoders,
gw_decoders,
Expand Down Expand Up @@ -631,7 +631,7 @@ We define loss coefficients for the different losses. Note that `LossCoefs` is a

Finally we make the GlobalWorkspace and train it.
```python
global_workspace = GlobalWorkspace(
global_workspace = GlobalWorkspace2Domains(
domain_mods,
gw_encoders,
gw_decoders,
Expand Down
4 changes: 2 additions & 2 deletions examples/main_example/train_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lightning.pytorch.callbacks import ModelCheckpoint
from torch import nn

from shimmer import GlobalWorkspace, GWDecoder, GWEncoder, LossCoefs
from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder, LossCoefs
from shimmer.modules.global_workspace import SchedulerArgs


Expand Down Expand Up @@ -76,7 +76,7 @@ def train_gw():

n_epochs = 4

global_workspace = GlobalWorkspace(
global_workspace = GlobalWorkspace2Domains(
domain_mods,
gw_encoders,
gw_decoders,
Expand Down
8 changes: 4 additions & 4 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.global_workspace import (
GlobalWorkspace,
GlobalWorkspace2Domains,
GlobalWorkspaceBase,
GlobalWorkspaceBayesian,
GWPredictions,
Expand All @@ -23,7 +23,7 @@
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
GWLosses,
GWLosses2Domains,
GWLossesBase,
GWLossesBayesian,
LossCoefs,
Expand Down Expand Up @@ -68,7 +68,7 @@
"SchedulerArgs",
"GWPredictions",
"GlobalWorkspaceBase",
"GlobalWorkspace",
"GlobalWorkspace2Domains",
"GlobalWorkspaceBayesian",
"pretrained_global_workspace",
"LossOutput",
Expand All @@ -85,7 +85,7 @@
"LossCoefs",
"BroadcastLossCoefs",
"GWLossesBase",
"GWLosses",
"GWLosses2Domains",
"GWLossesBayesian",
"RepeatedDataset",
"batch_cycles",
Expand Down
8 changes: 4 additions & 4 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.global_workspace import (
GlobalWorkspace,
GlobalWorkspace2Domains,
GlobalWorkspaceBase,
GlobalWorkspaceBayesian,
GWPredictions,
Expand All @@ -24,7 +24,7 @@
)
from shimmer.modules.losses import (
BroadcastLossCoefs,
GWLosses,
GWLosses2Domains,
GWLossesBase,
GWLossesBayesian,
LossCoefs,
Expand Down Expand Up @@ -54,7 +54,7 @@
"SchedulerArgs",
"GWPredictions",
"GlobalWorkspaceBase",
"GlobalWorkspace",
"GlobalWorkspace2Domains",
"GlobalWorkspaceBayesian",
"pretrained_global_workspace",
"LossOutput",
Expand All @@ -72,7 +72,7 @@
"LossCoefs",
"BroadcastLossCoefs",
"GWLossesBase",
"GWLosses",
"GWLosses2Domains",
"GWLossesBayesian",
"RepeatedDataset",
"reparameterize",
Expand Down
51 changes: 40 additions & 11 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from shimmer.modules.losses import (
BroadcastLossCoefs,
GWLosses,
GWLosses2Domains,
GWLossesBase,
GWLossesBayesian,
GWLossesFusion,
LossCoefs,
)
from shimmer.modules.selection import (
Expand Down Expand Up @@ -466,7 +466,9 @@ class GWPredictions(GWPredictionsBase):
"""


class GlobalWorkspace(GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses]):
class GlobalWorkspace2Domains(
GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses2Domains]
):
"""
A simple 2-domains max flavor of GlobalWorkspaceBase.
Expand Down Expand Up @@ -519,7 +521,7 @@ def __init__(
torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
)
selection_mod = SingleDomainSelection()
loss_mod = GWLosses(
loss_mod = GWLosses2Domains(
gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
)

Expand Down Expand Up @@ -559,9 +561,7 @@ def forward( # type: ignore
)


class GlobalWorkspaceFusion(
GlobalWorkspaceBase[GWModule, RandomSelection, GWLossesFusion]
):
class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
"""The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
This is used to simplify a Global Workspace instanciation and only overrides the
Expand Down Expand Up @@ -617,7 +617,7 @@ def __init__(
)

selection_mod = RandomSelection(selection_temperature)
loss_mod = GWLossesFusion(
loss_mod = GWLosses(
gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
)

Expand All @@ -630,6 +630,33 @@ def __init__(
scheduler_args,
)

def forward( # type: ignore
self,
latent_domains: LatentsDomainGroupsT,
) -> GWPredictions:
"""
Computes demi-cycles, cycles, and translations.
Args:
latent_domains (`LatentsT`): Groups of domains for the computation.
Returns:
`GWPredictions`: the predictions on the batch.
"""
return GWPredictions(
demi_cycles=batch_demi_cycles(
self.gw_mod, self.selection_mod, latent_domains
),
cycles=batch_cycles(
self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
),
translations=batch_translations(
self.gw_mod, self.selection_mod, latent_domains
),
# TODO: add other combinations
**super().forward(latent_domains),
)


class GlobalWorkspaceBayesian(
GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
Expand Down Expand Up @@ -753,7 +780,7 @@ def pretrained_global_workspace(
loss_coefs: LossCoefs,
contrastive_fn: ContrastiveLossType,
**kwargs,
) -> GlobalWorkspace:
) -> GlobalWorkspace2Domains:
"""
Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint.
Expand Down Expand Up @@ -785,16 +812,18 @@ def pretrained_global_workspace(
domain_mods = freeze_domain_modules(domain_mods)
gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
selection_mod = SingleDomainSelection()
loss_mod = GWLosses(gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn)
loss_mod = GWLosses2Domains(
gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
)

gw = GlobalWorkspace.load_from_checkpoint(
gw = GlobalWorkspace2Domains.load_from_checkpoint(
checkpoint_path,
gw_mod=gw_mod,
selection_mid=selection_mod,
loss_coefs=loss_coefs,
loss_mod=loss_mod,
**kwargs,
)
if not isinstance(gw, GlobalWorkspace):
if not isinstance(gw, GlobalWorkspace2Domains):
raise TypeError("model should be of type GlobalWorkspace")
return gw
4 changes: 2 additions & 2 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ class LossCoefs(TypedDict, total=False):
"""Contrastive loss coefficient."""


class GWLosses(GWLossesBase):
class GWLosses2Domains(GWLossesBase):
"""
Implementation of `GWLossesBase` used for `GWModule`.
"""
Expand Down Expand Up @@ -659,7 +659,7 @@ class BroadcastLossCoefs(TypedDict, total=False):
"""translation loss coefficient. Translation, like cycles, can be many-to-one."""


class GWLossesFusion(GWLossesBase):
class GWLosses(GWLossesBase):
"""
Implementation of `GWLossesBase` for fusion-based models.
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch import nn

from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.global_workspace import GlobalWorkspaceFusion
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.losses import BroadcastLossCoefs


Expand Down Expand Up @@ -39,7 +39,7 @@ def test_broadcast_loss():
"contrastives": 0.1,
}

gw_fusion = GlobalWorkspaceFusion(
gw_fusion = GlobalWorkspace(
domain_mods,
gw_encoders,
gw_decoders,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_training.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch.utils.data
from utils import DummyData, DummyDataset, DummyDomainModule

from shimmer import GlobalWorkspace, GWDecoder, GWEncoder
from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder
from shimmer.modules.selection import SingleDomainSelection


Expand Down Expand Up @@ -61,7 +61,7 @@ def test_training():
),
}

gw = GlobalWorkspace(
gw = GlobalWorkspace2Domains(
domains,
gw_encoders,
gw_decoders,
Expand Down

0 comments on commit d2b67c0

Please sign in to comment.