From 7161b16730a8e9856b8c4e0210808f394bcdfc69 Mon Sep 17 00:00:00 2001
From: bdvllrs <bdvllrs@gmail.com>
Date: Mon, 22 Jan 2024 16:12:05 +0000
Subject: [PATCH] Add pretrained_global_workspace

---
 shimmer/__init__.py                 | 12 ++++++----
 shimmer/modules/__init__.py         | 12 ++++++----
 shimmer/modules/global_workspace.py | 36 ++++++++++++++++++++++++++---
 shimmer/modules/gw_module.py        | 10 ++++----
 4 files changed, 54 insertions(+), 16 deletions(-)

diff --git a/shimmer/__init__.py b/shimmer/__init__.py
index ce6b4282..494864c0 100644
--- a/shimmer/__init__.py
+++ b/shimmer/__init__.py
@@ -3,9 +3,11 @@
 from shimmer.modules.domain import DomainModule
 from shimmer.modules.global_workspace import (GlobalWorkspace,
                                               GlobalWorkspaceBase,
-                                              VariationalGlobalWorkspace)
-from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule,
-                                       GWDecoder, GWEncoder, GWInterface,
+                                              SchedulerArgs,
+                                              VariationalGlobalWorkspace,
+                                              pretrained_global_workspace)
+from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
+                                       GWEncoder, GWInterface, GWInterfaceBase,
                                        GWModule, VariationalGWEncoder,
                                        VariationalGWInterface,
                                        VariationalGWModule)
@@ -19,7 +21,7 @@
     "load_structured_config",
     "ShimmerInfoConfig",
     "DomainModule",
-    "BaseGWInterface",
+    "GWInterfaceBase",
     "DeterministicGWModule",
     "GWDecoder",
     "GWEncoder",
@@ -34,4 +36,6 @@
     "GlobalWorkspace",
     "GlobalWorkspaceBase",
     "VariationalGlobalWorkspace",
+    "SchedulerArgs",
+    "pretrained_global_workspace",
 ]
diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py
index 8d417238..8c669a69 100644
--- a/shimmer/modules/__init__.py
+++ b/shimmer/modules/__init__.py
@@ -1,9 +1,11 @@
 from shimmer.modules.domain import DomainModule
 from shimmer.modules.global_workspace import (GlobalWorkspace,
                                               GlobalWorkspaceBase,
-                                              VariationalGlobalWorkspace)
-from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule,
-                                       GWDecoder, GWEncoder, GWInterface,
+                                              SchedulerArgs,
+                                              VariationalGlobalWorkspace,
+                                              pretrained_global_workspace)
+from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
+                                       GWEncoder, GWInterface, GWInterfaceBase,
                                        GWModule, VariationalGWEncoder,
                                        VariationalGWInterface,
                                        VariationalGWModule)
@@ -12,7 +14,7 @@
 
 __all__ = [
     "DomainModule",
-    "BaseGWInterface",
+    "GWInterfaceBase",
     "DeterministicGWModule",
     "GWDecoder",
     "GWEncoder",
@@ -27,4 +29,6 @@
     "GlobalWorkspace",
     "GlobalWorkspaceBase",
     "VariationalGlobalWorkspace",
+    "SchedulerArgs",
+    "pretrained_global_workspace",
 ]
diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index 8309cab7..348088d3 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -1,4 +1,5 @@
 from collections.abc import Iterable, Mapping
+from pathlib import Path
 from typing import Any, TypedDict, cast
 
 import torch
@@ -9,7 +10,7 @@
 
 from shimmer.modules.dict_buffer import DictBuffer
 from shimmer.modules.domain import DomainModule
-from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule,
+from shimmer.modules.gw_module import (DeterministicGWModule, GWInterfaceBase,
                                        GWModule, VariationalGWModule)
 from shimmer.modules.losses import (DeterministicGWLosses, GWLosses, LatentsT,
                                     VariationalGWLosses)
@@ -268,7 +269,7 @@ class GlobalWorkspace(GlobalWorkspaceBase):
     def __init__(
         self,
         domain_mods: Mapping[str, DomainModule],
-        gw_interfaces: Mapping[str, BaseGWInterface],
+        gw_interfaces: Mapping[str, GWInterfaceBase],
         workspace_dim: int,
         loss_coefs: dict[str, torch.Tensor],
         optim_lr: float = 1e-3,
@@ -295,7 +296,7 @@ class VariationalGlobalWorkspace(GlobalWorkspaceBase):
     def __init__(
         self,
         domain_mods: Mapping[str, DomainModule],
-        gw_interfaces: Mapping[str, BaseGWInterface],
+        gw_interfaces: Mapping[str, GWInterfaceBase],
         workspace_dim: int,
         loss_coefs: dict[str, torch.Tensor],
         var_contrastive_loss: bool = False,
@@ -319,3 +320,32 @@ def __init__(
             optim_weight_decay,
             scheduler_args,
         )
+
+
+def pretrained_global_workspace(
+    checkpoint_path: str | Path,
+    domain_mods: Mapping[str, DomainModule],
+    gw_interfaces: Mapping[str, GWInterfaceBase],
+    workspace_dim: int,
+    loss_coefs: dict[str, torch.Tensor],
+    var_contrastive_loss: bool = False,
+    **kwargs,
+) -> GlobalWorkspaceBase:
+    gw_mod = VariationalGWModule(gw_interfaces, workspace_dim)
+    domain_mods = freeze_domain_modules(domain_mods)
+    coef_buffers = DictBuffer(loss_coefs)
+    loss_mod = VariationalGWLosses(
+        gw_mod, domain_mods, coef_buffers, var_contrastive_loss
+    )
+
+    return cast(
+        GlobalWorkspaceBase,
+        GlobalWorkspaceBase.load_from_checkpoint(
+            checkpoint_path,
+            domain_mods=domain_mods,
+            gw_mod=gw_mod,
+            coef_buffers=coef_buffers,
+            loss_mod=loss_mod,
+            **kwargs,
+        ),
+    )
diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py
index 3137d3ae..e22d4f2e 100644
--- a/shimmer/modules/gw_module.py
+++ b/shimmer/modules/gw_module.py
@@ -80,7 +80,7 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         return self.layers(x), self.uncertainty_level.expand(x.size(0), -1)
 
 
-class BaseGWInterface(nn.Module, ABC):
+class GWInterfaceBase(nn.Module, ABC):
     def __init__(
         self, domain_module: DomainModule, workspace_dim: int
     ) -> None:
@@ -99,12 +99,12 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:
 
 class GWModule(nn.Module, ABC):
     def __init__(
-        self, gw_interfaces: Mapping[str, BaseGWInterface], workspace_dim: int
+        self, gw_interfaces: Mapping[str, GWInterfaceBase], workspace_dim: int
     ) -> None:
         super().__init__()
         # casting for LSP autocompletion
         self.gw_interfaces = cast(
-            dict[str, BaseGWInterface], nn.ModuleDict(gw_interfaces)
+            dict[str, GWInterfaceBase], nn.ModuleDict(gw_interfaces)
         )
         self.workspace_dim = workspace_dim
 
@@ -234,7 +234,7 @@ def cycle(
         ...
 
 
-class GWInterface(BaseGWInterface):
+class GWInterface(GWInterfaceBase):
     def __init__(
         self,
         domain_module: DomainModule,
@@ -309,7 +309,7 @@ def cycle(
         }
 
 
-class VariationalGWInterface(BaseGWInterface):
+class VariationalGWInterface(GWInterfaceBase):
     def __init__(
         self,
         domain_module: DomainModule,