Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename GlobalWorkspace to GlobalWorkspaceFusion #189

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/q_and_a.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ To get insipiration, you can look at the source code of

## How can I change the loss function?
If you are using pre-made GW architecture
([`GlobalWorkspace`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace),
([`GlobalWorkspace2Domains`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspace2Domains),
[`GlobalWorkspaceFusion`](https://ruflab.github.io/shimmer/latest/shimmer/modules/global_workspace.html#GlobalWorkspaceFusion)) and want to update the loss
used for demi-cycles, cycles, translations or broadcast, you can do so directly from
your definition of the
Expand Down
8 changes: 4 additions & 4 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,8 +721,8 @@ def __init__(
)


class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
"""The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
class GlobalWorkspaceFusion(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
"""The fusion (with broadcast loss) flavor of GlobalWorkspaceBase.

This is used to simplify a Global Workspace instanciation and only overrides the
`__init__` method.
Expand Down Expand Up @@ -841,10 +841,10 @@ def pretrained_global_workspace(
fusion_activation_fn (`Callable[[torch.Tensor], torch.Tensor]`): activation
function to fuse the domains.
**kwargs: additional arguments to pass to
`GlobalWorkspace.load_from_checkpoint`.
`GlobalWorkspace2Domains.load_from_checkpoint`.

Returns:
`GlobalWorkspace`: the pretrained `GlobalWorkspace`.
`GlobalWorkspace2Domains`: the pretrained `GlobalWorkspace2Domains`.

Raises:
`TypeError`: if loaded type is not `GlobalWorkspace`.
Expand Down
4 changes: 2 additions & 2 deletions tests/save_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from utils import DummyDomainModule

from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspaceFusion

here = Path(__file__).parent

Expand Down Expand Up @@ -54,7 +54,7 @@ def save_gw_ckpt():
workspace_dim=16,
loss_coefs={},
)
gw = GlobalWorkspace(
gw = GlobalWorkspaceFusion(
domains,
gw_encoders,
gw_decoders,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn

from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspaceFusion
from shimmer.modules.losses import BroadcastLossCoefs


Expand Down Expand Up @@ -43,7 +43,7 @@ def test_broadcast_loss():
"contrastives": 0.1,
}

gw_fusion = GlobalWorkspace(
gw_fusion = GlobalWorkspaceFusion(
domain_mods,
gw_encoders,
gw_decoders,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ckpt_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from utils import DummyDomainModule

from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder
from shimmer.modules.global_workspace import GlobalWorkspace
from shimmer.modules.global_workspace import GlobalWorkspaceFusion
from shimmer.utils import MIGRATION_DIR

here = Path(__file__).parent
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_ckpt_migration_gw():
),
}

gw = GlobalWorkspace(
gw = GlobalWorkspaceFusion(
domains,
gw_encoders,
gw_decoders,
Expand Down