From cbd16cbd3453e4bcd78c5497128331578559a7f4 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Fri, 29 Nov 2024 10:44:45 +0100 Subject: [PATCH] Rename GlobalWorkspace to GlobalWorkspaceFusion (#189) For now, GlobalWorkspace2Domains should be favoured --- docs/q_and_a.md | 2 +- shimmer/modules/global_workspace.py | 8 ++++---- tests/save_model.py | 4 ++-- tests/test_broadcast.py | 4 ++-- tests/test_ckpt_migrations.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/q_and_a.md b/docs/q_and_a.md index 543e88e1..2d92978f 100644 --- a/docs/q_and_a.md +++ b/docs/q_and_a.md @@ -18,7 +18,7 @@ To get insipiration, you can look at the source code of ## How can I change the loss function? If you are using pre-made GW architecture -([`GlobalWorkspace`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace), +([`GlobalWorkspace2Domains`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace2Domains), [`GlobalWorkspaceFusion`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceFusion)) and want to update the loss used for demi-cycles, cycles, translations or broadcast, you can do so directly from your definition of the diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 9ca55fe7..b8620273 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -721,8 +721,8 @@ def __init__( ) -class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]): - """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase. +class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]): + """The fusion (with broadcast loss) flavor of GlobalWorkspaceBase. This is used to simplify a Global Workspace instanciation and only overrides the `__init__` method. @@ -841,10 +841,10 @@ def pretrained_global_workspace( fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation function to fuse the domains. **kwargs: additional arguments to pass to - `GlobalWorkspace.load_from_checkpoint`. + `GlobalWorkspace2Domains.load_from_checkpoint`. Returns: - `GlobalWorkspace`: the pretrained `GlobalWorkspace`. + `GlobalWorkspace2Domains`: the pretrained `GlobalWorkspace2Domains`. Raises: `TypeError`: if loaded type is not `GlobalWorkspace`. diff --git a/tests/save_model.py b/tests/save_model.py index e8517e6b..c7e546b3 100644 --- a/tests/save_model.py +++ b/tests/save_model.py @@ -4,7 +4,7 @@ from utils import DummyDomainModule from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder -from shimmer.modules.global_workspace import GlobalWorkspace +from shimmer.modules.global_workspace import GlobalWorkspaceFusion here = Path(__file__).parent @@ -54,7 +54,7 @@ def save_gw_ckpt(): workspace_dim=16, loss_coefs={}, ) - gw = GlobalWorkspace( + gw = GlobalWorkspaceFusion( domains, gw_encoders, gw_decoders, diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index bde29e8d..012e169e 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -4,7 +4,7 @@ from torch import nn from shimmer.modules.domain import DomainModule, LossOutput -from shimmer.modules.global_workspace import GlobalWorkspace +from shimmer.modules.global_workspace import GlobalWorkspaceFusion from shimmer.modules.losses import BroadcastLossCoefs @@ -43,7 +43,7 @@ def test_broadcast_loss(): "contrastives": 0.1, } - gw_fusion = GlobalWorkspace( + gw_fusion = GlobalWorkspaceFusion( domain_mods, gw_encoders, gw_decoders, diff --git a/tests/test_ckpt_migrations.py b/tests/test_ckpt_migrations.py index 09cb98fc..963fb3fc 100644 --- a/tests/test_ckpt_migrations.py +++ b/tests/test_ckpt_migrations.py @@ -6,7 +6,7 @@ from utils import DummyDomainModule from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder -from shimmer.modules.global_workspace import GlobalWorkspace +from shimmer.modules.global_workspace import GlobalWorkspaceFusion from shimmer.utils import MIGRATION_DIR here = Path(__file__).parent @@ -121,7 +121,7 @@ def test_ckpt_migration_gw(): ), } - gw = GlobalWorkspace( + gw = GlobalWorkspaceFusion( domains, gw_encoders, gw_decoders,