Skip to content

Commit

Permalink
Add pretrained_global_workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Jan 22, 2024
1 parent ac0cd34 commit 7161b16
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 16 deletions.
12 changes: 8 additions & 4 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from shimmer.modules.domain import DomainModule
from shimmer.modules.global_workspace import (GlobalWorkspace,
GlobalWorkspaceBase,
VariationalGlobalWorkspace)
from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule,
GWDecoder, GWEncoder, GWInterface,
SchedulerArgs,
VariationalGlobalWorkspace,
pretrained_global_workspace)
from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
GWEncoder, GWInterface, GWInterfaceBase,
GWModule, VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule)
Expand All @@ -19,7 +21,7 @@
"load_structured_config",
"ShimmerInfoConfig",
"DomainModule",
"BaseGWInterface",
"GWInterfaceBase",
"DeterministicGWModule",
"GWDecoder",
"GWEncoder",
Expand All @@ -34,4 +36,6 @@
"GlobalWorkspace",
"GlobalWorkspaceBase",
"VariationalGlobalWorkspace",
"SchedulerArgs",
"pretrained_global_workspace",
]
12 changes: 8 additions & 4 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from shimmer.modules.domain import DomainModule
from shimmer.modules.global_workspace import (GlobalWorkspace,
GlobalWorkspaceBase,
VariationalGlobalWorkspace)
from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule,
GWDecoder, GWEncoder, GWInterface,
SchedulerArgs,
VariationalGlobalWorkspace,
pretrained_global_workspace)
from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
GWEncoder, GWInterface, GWInterfaceBase,
GWModule, VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule)
Expand All @@ -12,7 +14,7 @@

__all__ = [
"DomainModule",
"BaseGWInterface",
"GWInterfaceBase",
"DeterministicGWModule",
"GWDecoder",
"GWEncoder",
Expand All @@ -27,4 +29,6 @@
"GlobalWorkspace",
"GlobalWorkspaceBase",
"VariationalGlobalWorkspace",
"SchedulerArgs",
"pretrained_global_workspace",
]
36 changes: 33 additions & 3 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable, Mapping
from pathlib import Path
from typing import Any, TypedDict, cast

import torch
Expand All @@ -9,7 +10,7 @@

from shimmer.modules.dict_buffer import DictBuffer
from shimmer.modules.domain import DomainModule
from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule,
from shimmer.modules.gw_module import (DeterministicGWModule, GWInterfaceBase,
GWModule, VariationalGWModule)
from shimmer.modules.losses import (DeterministicGWLosses, GWLosses, LatentsT,
VariationalGWLosses)
Expand Down Expand Up @@ -268,7 +269,7 @@ class GlobalWorkspace(GlobalWorkspaceBase):
def __init__(
self,
domain_mods: Mapping[str, DomainModule],
gw_interfaces: Mapping[str, BaseGWInterface],
gw_interfaces: Mapping[str, GWInterfaceBase],
workspace_dim: int,
loss_coefs: dict[str, torch.Tensor],
optim_lr: float = 1e-3,
Expand All @@ -295,7 +296,7 @@ class VariationalGlobalWorkspace(GlobalWorkspaceBase):
def __init__(
self,
domain_mods: Mapping[str, DomainModule],
gw_interfaces: Mapping[str, BaseGWInterface],
gw_interfaces: Mapping[str, GWInterfaceBase],
workspace_dim: int,
loss_coefs: dict[str, torch.Tensor],
var_contrastive_loss: bool = False,
Expand All @@ -319,3 +320,32 @@ def __init__(
optim_weight_decay,
scheduler_args,
)


def pretrained_global_workspace(
    checkpoint_path: str | Path,
    domain_mods: Mapping[str, DomainModule],
    gw_interfaces: Mapping[str, GWInterfaceBase],
    workspace_dim: int,
    loss_coefs: dict[str, torch.Tensor],
    var_contrastive_loss: bool = False,
    **kwargs,
) -> GlobalWorkspaceBase:
    """Restore a global workspace from a checkpoint file.

    The components the checkpointed module expects are rebuilt here (a
    ``VariationalGWModule``, the domain modules after being passed through
    ``freeze_domain_modules``, a ``DictBuffer`` of loss coefficients and a
    ``VariationalGWLosses``) and forwarded as keyword arguments to
    ``GlobalWorkspaceBase.load_from_checkpoint``.

    Args:
        checkpoint_path: Location of the checkpoint to load.
        domain_mods: Domain modules, keyed by domain name.
        gw_interfaces: GW interfaces, keyed by domain name.
        workspace_dim: Dimension given to the ``VariationalGWModule``.
        loss_coefs: Loss coefficients wrapped in a ``DictBuffer``.
        var_contrastive_loss: Flag forwarded to ``VariationalGWLosses``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``load_from_checkpoint``.

    Returns:
        The object produced by ``load_from_checkpoint``, typed as
        ``GlobalWorkspaceBase``.
    """
    # NOTE(review): a variational GW module/loss pair is always built here,
    # whatever kind of workspace the checkpoint holds — confirm intended.
    variational_gw = VariationalGWModule(gw_interfaces, workspace_dim)
    frozen_domains = freeze_domain_modules(domain_mods)
    coefficients = DictBuffer(loss_coefs)
    losses = VariationalGWLosses(
        variational_gw,
        frozen_domains,
        coefficients,
        var_contrastive_loss,
    )

    restored = GlobalWorkspaceBase.load_from_checkpoint(
        checkpoint_path,
        domain_mods=frozen_domains,
        gw_mod=variational_gw,
        coef_buffers=coefficients,
        loss_mod=losses,
        **kwargs,
    )
    # cast() only informs the type checker; it performs no runtime check.
    return cast(GlobalWorkspaceBase, restored)
10 changes: 5 additions & 5 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return self.layers(x), self.uncertainty_level.expand(x.size(0), -1)


class BaseGWInterface(nn.Module, ABC):
class GWInterfaceBase(nn.Module, ABC):
def __init__(
self, domain_module: DomainModule, workspace_dim: int
) -> None:
Expand All @@ -99,12 +99,12 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:

class GWModule(nn.Module, ABC):
def __init__(
self, gw_interfaces: Mapping[str, BaseGWInterface], workspace_dim: int
self, gw_interfaces: Mapping[str, GWInterfaceBase], workspace_dim: int
) -> None:
super().__init__()
# casting for LSP autocompletion
self.gw_interfaces = cast(
dict[str, BaseGWInterface], nn.ModuleDict(gw_interfaces)
dict[str, GWInterfaceBase], nn.ModuleDict(gw_interfaces)
)
self.workspace_dim = workspace_dim

Expand Down Expand Up @@ -234,7 +234,7 @@ def cycle(
...


class GWInterface(BaseGWInterface):
class GWInterface(GWInterfaceBase):
def __init__(
self,
domain_module: DomainModule,
Expand Down Expand Up @@ -309,7 +309,7 @@ def cycle(
}


class VariationalGWInterface(BaseGWInterface):
class VariationalGWInterface(GWInterfaceBase):
def __init__(
self,
domain_module: DomainModule,
Expand Down

0 comments on commit 7161b16

Please sign in to comment.