@@ -310,7 +310,7 @@ def execute_something():
310
310
except ValueError :
311
311
_check_signature (handler , "handler" , * (event_args + args ), ** kwargs )
312
312
self ._event_handlers [event_name ].append ((handler , args , kwargs ))
313
- self .logger .debug (f"added handler for event { event_name } " )
313
+ self .logger .debug (f"Added handler for event { event_name } " )
314
314
315
315
return RemovableEventHandle (event_name , handler , self )
316
316
@@ -406,7 +406,7 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
406
406
**event_kwargs: optional keyword args to be passed to all handlers.
407
407
408
408
"""
409
- self .logger .debug (f"firing handlers for event { event_name } " )
409
+ self .logger .debug (f"{ self . state . epoch } | { self . state . iteration } , Firing handlers for event { event_name } " )
410
410
self .last_event_name = event_name
411
411
for func , args , kwargs in self ._event_handlers [event_name ]:
412
412
kwargs .update (event_kwargs )
@@ -720,6 +720,11 @@ def switch_batch(engine):
720
720
if self .state .epoch_length is None and data is None :
721
721
raise ValueError ("epoch_length should be provided if data is None" )
722
722
723
+ if self .should_terminate :
724
+ # If the engine was terminated and is now resuming from the terminated state,
725
+ # we need to initialize iter_counter to 0
726
+ self ._init_iter .append (0 )
727
+
723
728
self .state .dataloader = data
724
729
return self ._internal_run ()
725
730
@@ -750,12 +755,13 @@ def _setup_dataloader_iter(self) -> None:
750
755
751
756
def _setup_engine (self ) -> None :
752
757
self ._setup_dataloader_iter ()
753
- iteration = self .state .iteration
754
758
755
- # Below we define initial counter value for _run_once_on_dataset to measure a single epoch
756
- if self .state .epoch_length is not None :
757
- iteration %= self .state .epoch_length
758
- self ._init_iter .append (iteration )
759
+ if len (self ._init_iter ) == 0 :
760
+ iteration = self .state .iteration
761
+ # Below we define initial counter value for _run_once_on_dataset to measure a single epoch
762
+ if self .state .epoch_length is not None :
763
+ iteration %= self .state .epoch_length
764
+ self ._init_iter .append (iteration )
759
765
760
766
def _internal_run (self ) -> State :
761
767
self .should_terminate = self .should_terminate_single_epoch = False
@@ -826,6 +832,11 @@ def _run_once_on_dataset(self) -> float:
826
832
start_time = time .time ()
827
833
828
834
# We need to setup iter_counter > 0 if we resume from an iteration
835
+ if len (self ._init_iter ) > 1 :
836
+ raise RuntimeError (
837
+ "Internal error, len(self._init_iter) should be 0 or 1, "
838
+ f"but got: { len (self ._init_iter )} , { self ._init_iter } "
839
+ )
829
840
iter_counter = self ._init_iter .pop () if len (self ._init_iter ) > 0 else 0
830
841
should_exit = False
831
842
try :
0 commit comments