@@ -32,7 +32,7 @@ class GlobalWorkspaceBase(LightningModule):
32
32
def __init__ (
33
33
self ,
34
34
gw_mod : GWModule ,
35
- domain_mods : dict [str , DomainModule ],
35
+ domain_mods : Mapping [str , DomainModule ],
36
36
coef_buffers : DictBuffer ,
37
37
loss_mod : GWLosses ,
38
38
optim_lr : float = 1e-3 ,
@@ -271,7 +271,7 @@ def __init__(
271
271
domain_mods : Mapping [str , DomainModule ],
272
272
gw_interfaces : Mapping [str , GWInterfaceBase ],
273
273
workspace_dim : int ,
274
- loss_coefs : dict [str , torch .Tensor ],
274
+ loss_coefs : Mapping [str , torch .Tensor ],
275
275
optim_lr : float = 1e-3 ,
276
276
optim_weight_decay : float = 0.0 ,
277
277
scheduler_args : SchedulerArgs | None = None ,
@@ -298,7 +298,7 @@ def __init__(
298
298
domain_mods : Mapping [str , DomainModule ],
299
299
gw_interfaces : Mapping [str , GWInterfaceBase ],
300
300
workspace_dim : int ,
301
- loss_coefs : dict [str , torch .Tensor ],
301
+ loss_coefs : Mapping [str , torch .Tensor ],
302
302
var_contrastive_loss : bool = False ,
303
303
optim_lr : float = 1e-3 ,
304
304
optim_weight_decay : float = 0.0 ,
@@ -327,7 +327,7 @@ def pretrained_global_workspace(
327
327
domain_mods : Mapping [str , DomainModule ],
328
328
gw_interfaces : Mapping [str , GWInterfaceBase ],
329
329
workspace_dim : int ,
330
- loss_coefs : dict [str , torch .Tensor ],
330
+ loss_coefs : Mapping [str , torch .Tensor ],
331
331
var_contrastive_loss : bool = False ,
332
332
** kwargs ,
333
333
) -> GlobalWorkspaceBase :
0 commit comments