Skip to content

Commit 7d90ee9

Browse files
authored
more generic outputs for training_step, validation_step and test_step (#148)
1 parent 2c9b60d commit 7d90ee9

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

shimmer/modules/global_workspace.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66
import torch
77
from lightning.pytorch import LightningModule
8-
from lightning.pytorch.utilities.types import (
9-
OptimizerLRScheduler,
10-
)
8+
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
119
from torch.nn import Module, ModuleDict
1210
from torch.optim.adamw import AdamW
1311
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
@@ -486,7 +484,7 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup
486484
for domains, latents in latents_domain.items()
487485
}
488486

489-
def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
487+
def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> STEP_OUTPUT:
490488
"""
491489
The generic step used in `training_step`, `validation_step` and
492490
`test_step`.
@@ -515,7 +513,7 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tenso
515513

516514
def validation_step( # type: ignore
517515
self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
518-
) -> torch.Tensor:
516+
) -> STEP_OUTPUT:
519517
"""Validation step used by lightning"""
520518

521519
batch = {frozenset(data.keys()): data}
@@ -527,7 +525,7 @@ def validation_step( # type: ignore
527525

528526
def test_step( # type: ignore
529527
self, data: Mapping[str, Any], batch_idx: int, dataloader_idx: int = 0
530-
) -> torch.Tensor:
528+
) -> STEP_OUTPUT:
531529
"""Test step used by lightning"""
532530

533531
batch = {frozenset(data.keys()): data}
@@ -539,7 +537,7 @@ def test_step( # type: ignore
539537

540538
def training_step( # type: ignore
541539
self, batch: Mapping[frozenset[str], Mapping[str, Any]], batch_idx: int
542-
) -> torch.Tensor:
540+
) -> STEP_OUTPUT:
543541
"""Training step used by lightning"""
544542

545543
return self.generic_step(batch, mode="train")

0 commit comments

Comments
 (0)