5
5
6
6
import torch
7
7
from lightning .pytorch import LightningModule
8
- from lightning .pytorch .utilities .types import (
9
- OptimizerLRScheduler ,
10
- )
8
+ from lightning .pytorch .utilities .types import STEP_OUTPUT , OptimizerLRScheduler
11
9
from torch .nn import Module , ModuleDict
12
10
from torch .optim .adamw import AdamW
13
11
from torch .optim .lr_scheduler import LRScheduler , OneCycleLR
@@ -486,7 +484,7 @@ def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroup
486
484
for domains , latents in latents_domain .items ()
487
485
}
488
486
489
- def generic_step (self , batch : RawDomainGroupsT , mode : ModelModeT ) -> torch . Tensor :
487
+ def generic_step (self , batch : RawDomainGroupsT , mode : ModelModeT ) -> STEP_OUTPUT :
490
488
"""
491
489
The generic step used in `training_step`, `validation_step` and
492
490
`test_step`.
@@ -515,7 +513,7 @@ def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tenso
515
513
516
514
def validation_step ( # type: ignore
517
515
self , data : RawDomainGroupT , batch_idx : int , dataloader_idx : int = 0
518
- ) -> torch . Tensor :
516
+ ) -> STEP_OUTPUT :
519
517
"""Validation step used by lightning"""
520
518
521
519
batch = {frozenset (data .keys ()): data }
@@ -527,7 +525,7 @@ def validation_step( # type: ignore
527
525
528
526
def test_step ( # type: ignore
529
527
self , data : Mapping [str , Any ], batch_idx : int , dataloader_idx : int = 0
530
- ) -> torch . Tensor :
528
+ ) -> STEP_OUTPUT :
531
529
"""Test step used by lightning"""
532
530
533
531
batch = {frozenset (data .keys ()): data }
@@ -539,7 +537,7 @@ def test_step( # type: ignore
539
537
540
538
def training_step ( # type: ignore
541
539
self , batch : Mapping [frozenset [str ], Mapping [str , Any ]], batch_idx : int
542
- ) -> torch . Tensor :
540
+ ) -> STEP_OUTPUT :
543
541
"""Training step used by lightning"""
544
542
545
543
return self .generic_step (batch , mode = "train" )
0 commit comments