Skip to content

Commit 7a22542

Browse files
committed
Set default value for dataloader_idx
1 parent 1ef896b commit 7a22542

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

shimmer/modules/global_workspace.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def generic_step(
187187
return losses["loss"]
188188

189189
def validation_step(
190-
self, data: Mapping[str, Any], _, dataloader_idx: int
190+
self, data: Mapping[str, Any], _, dataloader_idx: int = 0
191191
) -> torch.Tensor:
192192
batch = {frozenset(data.keys()): data}
193193
for domain in data.keys():
@@ -197,7 +197,7 @@ def validation_step(
197197
return self.generic_step(batch, mode="val/ood")
198198

199199
def test_step(
200-
self, data: Mapping[str, Any], _, dataloader_idx: int
200+
self, data: Mapping[str, Any], _, dataloader_idx: int = 0
201201
) -> torch.Tensor:
202202
batch = {frozenset(data.keys()): data}
203203
for domain in data.keys():

0 commit comments

Comments
 (0)