Commit 55289de

Rename gw_latent_dim to workspace_dim.
Add workspace_dim attribute in Global Workspace. Fixes #1
Parent: 4790591
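
For downstream code this is a pure keyword rename: every constructor that previously took gw_latent_dim now takes workspace_dim, and the dimension is additionally exposed as an attribute. A minimal migration sketch (the GlobalWorkspace call mirrors tests/test_training.py below; domains, gw_interfaces, and the empty loss_coefs are assumed to be built as before):

    # Before this commit, the workspace size was passed as gw_latent_dim:
    # gw = GlobalWorkspace(domains, gw_interfaces, gw_latent_dim=16, loss_coefs={})

    # After this commit, the same value is passed as workspace_dim:
    gw = GlobalWorkspace(domains, gw_interfaces, workspace_dim=16, loss_coefs={})

    # It can also be read back through the new property (see
    # shimmer/modules/global_workspace.py below):
    assert gw.workspace_dim == 16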

File tree: 4 files changed (29 additions, 25 deletions)

README.md (+1, -1)

@@ -68,7 +68,7 @@ from shimmer import GWInterface
 my_domain = MyDomain()
 my_domain_gw_interface = GWInterface(
     my_domain,
-    gw_latent_dim=12, # latent dim of the global workspace
+    workspace_dim=12, # latent dim of the global workspace
     encoder_hidden_dim=32, # hidden dimension for the GW encoder
     encoder_n_layers=3, # n layers to use for the GW encoder
     decoder_hidden_dim=32, # hidden dimension for the GW decoder

shimmer/modules/global_workspace.py (+11, -7)

@@ -9,7 +9,7 @@
 
 from shimmer.modules.dict_buffer import DictBuffer
 from shimmer.modules.domain import DomainModule
-from shimmer.modules.gw_module import (DeterministicGWModule, GWInterface,
+from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule,
                                        GWModule, VariationalGWModule)
 from shimmer.modules.losses import (DeterministicGWLosses, GWLosses, LatentsT,
                                     VariationalGWLosses)
@@ -60,6 +60,10 @@ def __init__(
         if scheduler_args is not None:
             self.scheduler_args.update(scheduler_args)
 
+    @property
+    def workspace_dim(self):
+        return self.gw_mod.workspace_dim
+
     def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
         return self.gw_mod.encode(x)
 
@@ -264,14 +268,14 @@ class GlobalWorkspace(GlobalWorkspaceBase):
     def __init__(
         self,
         domain_mods: Mapping[str, DomainModule],
-        gw_interfaces: Mapping[str, GWInterface],
-        gw_latent_dim: int,
+        gw_interfaces: Mapping[str, BaseGWInterface],
+        workspace_dim: int,
         loss_coefs: dict[str, torch.Tensor],
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
     ) -> None:
-        gw_mod = DeterministicGWModule(gw_interfaces, gw_latent_dim)
+        gw_mod = DeterministicGWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
         coef_buffers = DictBuffer(loss_coefs)
         loss_mod = DeterministicGWLosses(gw_mod, domain_mods, coef_buffers)
@@ -291,15 +295,15 @@ class VariationalGlobalWorkspace(GlobalWorkspaceBase):
     def __init__(
         self,
         domain_mods: Mapping[str, DomainModule],
-        gw_interfaces: Mapping[str, GWInterface],
-        gw_latent_dim: int,
+        gw_interfaces: Mapping[str, BaseGWInterface],
+        workspace_dim: int,
         loss_coefs: dict[str, torch.Tensor],
         var_contrastive_loss: bool = False,
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
     ) -> None:
-        gw_mod = VariationalGWModule(gw_interfaces, gw_latent_dim)
+        gw_mod = VariationalGWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
         coef_buffers = DictBuffer(loss_coefs)
         loss_mod = VariationalGWLosses(
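
The new workspace_dim property on GlobalWorkspaceBase delegates to the underlying GWModule, so callers no longer have to reach into gw.gw_mod. A short usage sketch, assuming gw was constructed as in tests/test_training.py below (the classifier head is a hypothetical downstream component, not part of this commit):

    import torch

    # The property is equivalent to gw.gw_mod.workspace_dim:
    assert gw.workspace_dim == gw.gw_mod.workspace_dim == 16

    # e.g. size a downstream component from the workspace itself rather than
    # from a separately tracked constant:
    head = torch.nn.Linear(gw.workspace_dim, 10)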

shimmer/modules/gw_module.py (+12, -12)

@@ -82,11 +82,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 
 class BaseGWInterface(nn.Module, ABC):
     def __init__(
-        self, domain_module: DomainModule, gw_latent_dim: int
+        self, domain_module: DomainModule, workspace_dim: int
     ) -> None:
         super().__init__()
         self.domain_module = domain_module
-        self.gw_latent_dim = gw_latent_dim
+        self.workspace_dim = workspace_dim
 
     @abstractmethod
     def encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -99,14 +99,14 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:
 
 class GWModule(nn.Module, ABC):
     def __init__(
-        self, gw_interfaces: Mapping[str, BaseGWInterface], gw_latent_dim: int
+        self, gw_interfaces: Mapping[str, BaseGWInterface], workspace_dim: int
    ) -> None:
         super().__init__()
         # casting for LSP autocompletion
         self.gw_interfaces = cast(
             dict[str, BaseGWInterface], nn.ModuleDict(gw_interfaces)
         )
-        self.latent_dim = gw_latent_dim
+        self.workspace_dim = workspace_dim
 
     def on_before_gw_encode_dcy(
         self, x: Mapping[str, torch.Tensor]
@@ -238,22 +238,22 @@ class GWInterface(BaseGWInterface):
     def __init__(
         self,
         domain_module: DomainModule,
-        gw_latent_dim: int,
+        workspace_dim: int,
         encoder_hidden_dim: int,
         encoder_n_layers: int,
         decoder_hidden_dim: int,
         decoder_n_layers: int,
     ) -> None:
-        super().__init__(domain_module, gw_latent_dim)
+        super().__init__(domain_module, workspace_dim)
 
         self.encoder = GWEncoder(
             domain_module.latent_dim,
             encoder_hidden_dim,
-            gw_latent_dim,
+            workspace_dim,
             encoder_n_layers,
         )
         self.decoder = GWDecoder(
-            gw_latent_dim,
+            workspace_dim,
             decoder_hidden_dim,
             domain_module.latent_dim,
             decoder_n_layers,
@@ -313,22 +313,22 @@ class VariationalGWInterface(BaseGWInterface):
     def __init__(
         self,
         domain_module: DomainModule,
-        gw_latent_dim: int,
+        workspace_dim: int,
         encoder_hidden_dim: int,
         encoder_n_layers: int,
         decoder_hidden_dim: int,
         decoder_n_layers: int,
     ) -> None:
-        super().__init__(domain_module, gw_latent_dim)
+        super().__init__(domain_module, workspace_dim)
 
         self.encoder = VariationalGWEncoder(
             domain_module.latent_dim,
             encoder_hidden_dim,
-            gw_latent_dim,
+            workspace_dim,
             encoder_n_layers,
         )
         self.decoder = GWDecoder(
-            gw_latent_dim,
+            workspace_dim,
             decoder_hidden_dim,
             domain_module.latent_dim,
             decoder_n_layers,
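
Because BaseGWInterface now stores the dimension as self.workspace_dim, matching the attribute on GWModule, custom interfaces can rely on one consistent name. A minimal sketch of such a subclass (LinearGWInterface is hypothetical and not part of this commit; it assumes only the BaseGWInterface constructor and the abstract encode/decode signatures shown above):

    import torch
    from torch import nn

    from shimmer.modules.domain import DomainModule
    from shimmer.modules.gw_module import BaseGWInterface


    class LinearGWInterface(BaseGWInterface):
        # Illustrative only: single linear maps into and out of the workspace,
        # instead of the multi-layer GWEncoder/GWDecoder used by GWInterface.
        def __init__(self, domain_module: DomainModule, workspace_dim: int) -> None:
            super().__init__(domain_module, workspace_dim)
            # self.domain_module and self.workspace_dim are set by the base class
            self.project = nn.Linear(domain_module.latent_dim, self.workspace_dim)
            self.unproject = nn.Linear(self.workspace_dim, domain_module.latent_dim)

        def encode(self, x: torch.Tensor) -> torch.Tensor:
            return self.project(x)

        def decode(self, z: torch.Tensor) -> torch.Tensor:
            return self.unproject(z)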

tests/test_training.py (+5, -5)

@@ -16,28 +16,28 @@ def test_training():
         "a": DummyDomainModule(latent_dim=128),
     }
 
-    gw_latent_dim = 16
+    workspace_dim = 16
 
     gw_interfaces = {
         "v": GWInterface(
             domains["v"],
-            gw_latent_dim=gw_latent_dim,
+            workspace_dim=workspace_dim,
             encoder_hidden_dim=64,
             encoder_n_layers=1,
             decoder_hidden_dim=64,
             decoder_n_layers=1,
         ),
         "t": GWInterface(
             domains["t"],
-            gw_latent_dim=gw_latent_dim,
+            workspace_dim=workspace_dim,
             encoder_hidden_dim=64,
             encoder_n_layers=1,
             decoder_hidden_dim=64,
             decoder_n_layers=1,
         ),
         "a": GWInterface(
             domains["a"],
-            gw_latent_dim=gw_latent_dim,
+            workspace_dim=workspace_dim,
             encoder_hidden_dim=64,
             encoder_n_layers=1,
             decoder_hidden_dim=64,
@@ -48,7 +48,7 @@ def test_training():
     gw = GlobalWorkspace(
         domains,
         gw_interfaces,
-        gw_latent_dim=16,
+        workspace_dim=16,
         loss_coefs={},
     )
