Skip to content

Commit

Permalink
Rename GlobalWorkspace to GlobalWorkspaceFusion (#189)
Browse files Browse the repository at this point in the history
For now, GlobalWorkspace2Domains should be favoured
  • Loading branch information
bdvllrs authored Nov 29, 2024
1 parent 7c9bb5f commit cbd16cb
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/q_and_a.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ To get insipiration, you can look at the source code of

## How can I change the loss function?
If you are using pre-made GW architecture
([`GlobalWorkspace`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace),
([`GlobalWorkspace2Domains`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace2Domains),
[`GlobalWorkspaceFusion`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceFusion)) and want to update the loss
used for demi-cycles, cycles, translations or broadcast, you can do so directly from
your definition of the
Expand Down
8 changes: 4 additions & 4 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,8 +721,8 @@ def __init__(
)


class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
"""The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
"""The fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
This is used to simplify a Global Workspace instanciation and only overrides the
`__init__` method.
Expand Down Expand Up @@ -841,10 +841,10 @@ def pretrained_global_workspace(
fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
function to fuse the domains.
**kwargs: additional arguments to pass to
`GlobalWorkspace.load_from_checkpoint`.
`GlobalWorkspace2Domains.load_from_checkpoint`.
Returns:
`GlobalWorkspace`: the pretrained `GlobalWorkspace`.
`GlobalWorkspace2Domains`: the pretrained `GlobalWorkspace2Domains`.
Raises:
`TypeError`: if loaded type is not `GlobalWorkspace`.
Expand Down
4 changes: 2 additions & 2 deletions tests/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from utils import DummyDomainModule

from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspaceFusion

here = Path(__file__).parent

Expand Down Expand Up @@ -54,7 +54,7 @@ def save_gw_ckpt():
workspace_dim=16,
loss_coefs={},
)
gw = GlobalWorkspace(
gw = GlobalWorkspaceFusion(
domains,
gw_encoders,
gw_decoders,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn

from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspaceFusion
from shimmer.modules.losses import BroadcastLossCoefs


Expand Down Expand Up @@ -43,7 +43,7 @@ def test_broadcast_loss():
"contrastives": 0.1,
}

gw_fusion = GlobalWorkspace(
gw_fusion = GlobalWorkspaceFusion(
domain_mods,
gw_encoders,
gw_decoders,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ckpt_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from utils import DummyDomainModule

from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspaceFusion
from shimmer.utils import MIGRATION_DIR

here = Path(__file__).parent
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_ckpt_migration_gw():
),
}

gw = GlobalWorkspace(
gw = GlobalWorkspaceFusion(
domains,
gw_encoders,
gw_decoders,
Expand Down

0 comments on commit cbd16cb

Please sign in to comment.