From 7161b16730a8e9856b8c4e0210808f394bcdfc69 Mon Sep 17 00:00:00 2001 From: bdvllrs <bdvllrs@gmail.com> Date: Mon, 22 Jan 2024 16:12:05 +0000 Subject: [PATCH] Add pretrained_global_workspace --- shimmer/__init__.py | 12 ++++++---- shimmer/modules/__init__.py | 12 ++++++---- shimmer/modules/global_workspace.py | 36 ++++++++++++++++++++++++++--- shimmer/modules/gw_module.py | 10 ++++---- 4 files changed, 54 insertions(+), 16 deletions(-) diff --git a/shimmer/__init__.py b/shimmer/__init__.py index ce6b4282..494864c0 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -3,9 +3,11 @@ from shimmer.modules.domain import DomainModule from shimmer.modules.global_workspace import (GlobalWorkspace, GlobalWorkspaceBase, - VariationalGlobalWorkspace) -from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule, - GWDecoder, GWEncoder, GWInterface, + SchedulerArgs, + VariationalGlobalWorkspace, + pretrained_global_workspace) +from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder, + GWEncoder, GWInterface, GWInterfaceBase, GWModule, VariationalGWEncoder, VariationalGWInterface, VariationalGWModule) @@ -19,7 +21,7 @@ "load_structured_config", "ShimmerInfoConfig", "DomainModule", - "BaseGWInterface", + "GWInterfaceBase", "DeterministicGWModule", "GWDecoder", "GWEncoder", @@ -34,4 +36,6 @@ "GlobalWorkspace", "GlobalWorkspaceBase", "VariationalGlobalWorkspace", + "SchedulerArgs", + "pretrained_global_workspace", ] diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index 8d417238..8c669a69 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -1,9 +1,11 @@ from shimmer.modules.domain import DomainModule from shimmer.modules.global_workspace import (GlobalWorkspace, GlobalWorkspaceBase, - VariationalGlobalWorkspace) -from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule, - GWDecoder, GWEncoder, GWInterface, + SchedulerArgs, + VariationalGlobalWorkspace, + pretrained_global_workspace) +from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder, + GWEncoder, GWInterface, GWInterfaceBase, GWModule, VariationalGWEncoder, VariationalGWInterface, VariationalGWModule) @@ -12,7 +14,7 @@ __all__ = [ "DomainModule", - "BaseGWInterface", + "GWInterfaceBase", "DeterministicGWModule", "GWDecoder", "GWEncoder", @@ -27,4 +29,6 @@ "GlobalWorkspace", "GlobalWorkspaceBase", "VariationalGlobalWorkspace", + "SchedulerArgs", + "pretrained_global_workspace", ] diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 8309cab7..348088d3 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -1,4 +1,5 @@ from collections.abc import Iterable, Mapping +from pathlib import Path from typing import Any, TypedDict, cast import torch @@ -9,7 +10,7 @@ from shimmer.modules.dict_buffer import DictBuffer from shimmer.modules.domain import DomainModule -from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule, +from shimmer.modules.gw_module import (DeterministicGWModule, GWInterfaceBase, GWModule, VariationalGWModule) from shimmer.modules.losses import (DeterministicGWLosses, GWLosses, LatentsT, VariationalGWLosses) @@ -268,7 +269,7 @@ class GlobalWorkspace(GlobalWorkspaceBase): def __init__( self, domain_mods: Mapping[str, DomainModule], - gw_interfaces: Mapping[str, BaseGWInterface], + gw_interfaces: Mapping[str, GWInterfaceBase], workspace_dim: int, loss_coefs: dict[str, torch.Tensor], optim_lr: float = 1e-3, @@ -295,7 +296,7 @@ class VariationalGlobalWorkspace(GlobalWorkspaceBase): def __init__( self, domain_mods: Mapping[str, DomainModule], - gw_interfaces: Mapping[str, BaseGWInterface], + gw_interfaces: Mapping[str, GWInterfaceBase], workspace_dim: int, loss_coefs: dict[str, torch.Tensor], var_contrastive_loss: bool = False, @@ -319,3 +320,32 @@ def __init__( optim_weight_decay, scheduler_args, ) + + +def pretrained_global_workspace( + checkpoint_path: str | Path, + domain_mods: Mapping[str, DomainModule], + gw_interfaces: Mapping[str, GWInterfaceBase], + workspace_dim: int, + loss_coefs: dict[str, torch.Tensor], + var_contrastive_loss: bool = False, + **kwargs, +) -> GlobalWorkspaceBase: + gw_mod = VariationalGWModule(gw_interfaces, workspace_dim) + domain_mods = freeze_domain_modules(domain_mods) + coef_buffers = DictBuffer(loss_coefs) + loss_mod = VariationalGWLosses( + gw_mod, domain_mods, coef_buffers, var_contrastive_loss + ) + + return cast( + GlobalWorkspaceBase, + GlobalWorkspaceBase.load_from_checkpoint( + checkpoint_path, + domain_mods=domain_mods, + gw_mod=gw_mod, + coef_buffers=coef_buffers, + loss_mod=loss_mod, + **kwargs, + ), + ) diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 3137d3ae..e22d4f2e 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -80,7 +80,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return self.layers(x), self.uncertainty_level.expand(x.size(0), -1) -class BaseGWInterface(nn.Module, ABC): +class GWInterfaceBase(nn.Module, ABC): def __init__( self, domain_module: DomainModule, workspace_dim: int ) -> None: @@ -99,12 +99,12 @@ def decode(self, z: torch.Tensor) -> torch.Tensor: class GWModule(nn.Module, ABC): def __init__( - self, gw_interfaces: Mapping[str, BaseGWInterface], workspace_dim: int + self, gw_interfaces: Mapping[str, GWInterfaceBase], workspace_dim: int ) -> None: super().__init__() # casting for LSP autocompletion self.gw_interfaces = cast( - dict[str, BaseGWInterface], nn.ModuleDict(gw_interfaces) + dict[str, GWInterfaceBase], nn.ModuleDict(gw_interfaces) ) self.workspace_dim = workspace_dim @@ -234,7 +234,7 @@ def cycle( ... -class GWInterface(BaseGWInterface): +class GWInterface(GWInterfaceBase): def __init__( self, domain_module: DomainModule, @@ -309,7 +309,7 @@ def cycle( } -class VariationalGWInterface(BaseGWInterface): +class VariationalGWInterface(GWInterfaceBase): def __init__( self, domain_module: DomainModule,