Skip to content

Commit 559c991

Browse files
committed
replace types dict to Mapping
1 parent 7161b16 commit 559c991

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

shimmer/modules/dict_buffer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Generator, Iterator
1+
from collections.abc import Generator, Iterator, Mapping
22
from typing import Any
33

44
import torch
@@ -7,7 +7,7 @@
77

88
class DictBuffer(nn.Module):
99
def __init__(
10-
self, buffer_dict: dict[str, torch.Tensor], persistent: bool = True
10+
self, buffer_dict: Mapping[str, torch.Tensor], persistent: bool = True
1111
):
1212
super().__init__()
1313

shimmer/modules/global_workspace.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class GlobalWorkspaceBase(LightningModule):
3232
def __init__(
3333
self,
3434
gw_mod: GWModule,
35-
domain_mods: dict[str, DomainModule],
35+
domain_mods: Mapping[str, DomainModule],
3636
coef_buffers: DictBuffer,
3737
loss_mod: GWLosses,
3838
optim_lr: float = 1e-3,
@@ -271,7 +271,7 @@ def __init__(
271271
domain_mods: Mapping[str, DomainModule],
272272
gw_interfaces: Mapping[str, GWInterfaceBase],
273273
workspace_dim: int,
274-
loss_coefs: dict[str, torch.Tensor],
274+
loss_coefs: Mapping[str, torch.Tensor],
275275
optim_lr: float = 1e-3,
276276
optim_weight_decay: float = 0.0,
277277
scheduler_args: SchedulerArgs | None = None,
@@ -298,7 +298,7 @@ def __init__(
298298
domain_mods: Mapping[str, DomainModule],
299299
gw_interfaces: Mapping[str, GWInterfaceBase],
300300
workspace_dim: int,
301-
loss_coefs: dict[str, torch.Tensor],
301+
loss_coefs: Mapping[str, torch.Tensor],
302302
var_contrastive_loss: bool = False,
303303
optim_lr: float = 1e-3,
304304
optim_weight_decay: float = 0.0,
@@ -327,7 +327,7 @@ def pretrained_global_workspace(
327327
domain_mods: Mapping[str, DomainModule],
328328
gw_interfaces: Mapping[str, GWInterfaceBase],
329329
workspace_dim: int,
330-
loss_coefs: dict[str, torch.Tensor],
330+
loss_coefs: Mapping[str, torch.Tensor],
331331
var_contrastive_loss: bool = False,
332332
**kwargs,
333333
) -> GlobalWorkspaceBase:

0 commit comments

Comments
 (0)