@@ -23,13 +23,19 @@ class _Loop:
2323 def __init__ (self , trainer : "pl.Trainer" ) -> None :
2424 self ._restarting = False
2525 self ._loaded_from_state_dict = False
26+ self ._resuming_from_checkpoint = False
2627 self .trainer = trainer
2728
2829 @property
2930 def restarting (self ) -> bool :
3031 """Whether the state of this loop was reloaded and it needs to restart."""
3132 return self ._restarting
3233
34+ @property
35+ def is_resuming (self ) -> bool :
36+ """Whether we're resuming training from a checkpoint."""
37+ return self ._resuming_from_checkpoint
38+
3339 @restarting .setter
3440 def restarting (self , restarting : bool ) -> None :
3541 """Connects this loop's restarting value and its children."""
@@ -87,6 +93,7 @@ def load_state_dict(
8793 v .load_state_dict (state_dict .copy (), prefix + k + "." )
8894 self .restarting = True
8995 self ._loaded_from_state_dict = True
96+ self ._resuming_from_checkpoint = True
9097
9198 def _load_from_state_dict (self , state_dict : dict , prefix : str ) -> None :
9299 for k , v in self .__dict__ .items ():
@@ -102,4 +109,5 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
102109 def on_iteration_done (self ) -> None :
103110 self ._restarting = False
104111 self ._loaded_from_state_dict = False
112+ self ._resuming_from_checkpoint = False
105113 self .reset_restart_stage ()
0 commit comments