Skip to content

Commit 26f7cec

Browse files
authored
Fixing engine terminate behaviour when resumed (#2678)
1 parent 942af82 commit 26f7cec

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

ignite/engine/engine.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def execute_something():
310310
except ValueError:
311311
_check_signature(handler, "handler", *(event_args + args), **kwargs)
312312
self._event_handlers[event_name].append((handler, args, kwargs))
313-
self.logger.debug(f"added handler for event {event_name}")
313+
self.logger.debug(f"Added handler for event {event_name}")
314314

315315
return RemovableEventHandle(event_name, handler, self)
316316

@@ -406,7 +406,7 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
406406
**event_kwargs: optional keyword args to be passed to all handlers.
407407
408408
"""
409-
self.logger.debug(f"firing handlers for event {event_name}")
409+
self.logger.debug(f"{self.state.epoch} | {self.state.iteration}, Firing handlers for event {event_name}")
410410
self.last_event_name = event_name
411411
for func, args, kwargs in self._event_handlers[event_name]:
412412
kwargs.update(event_kwargs)
@@ -720,6 +720,11 @@ def switch_batch(engine):
720720
if self.state.epoch_length is None and data is None:
721721
raise ValueError("epoch_length should be provided if data is None")
722722

723+
if self.should_terminate:
724+
# If engine was terminated and now is resuming from terminated state
725+
# we need to initialize iter_counter as 0
726+
self._init_iter.append(0)
727+
723728
self.state.dataloader = data
724729
return self._internal_run()
725730

@@ -750,12 +755,13 @@ def _setup_dataloader_iter(self) -> None:
750755

751756
def _setup_engine(self) -> None:
752757
self._setup_dataloader_iter()
753-
iteration = self.state.iteration
754758

755-
# Below we define initial counter value for _run_once_on_dataset to measure a single epoch
756-
if self.state.epoch_length is not None:
757-
iteration %= self.state.epoch_length
758-
self._init_iter.append(iteration)
759+
if len(self._init_iter) == 0:
760+
iteration = self.state.iteration
761+
# Below we define initial counter value for _run_once_on_dataset to measure a single epoch
762+
if self.state.epoch_length is not None:
763+
iteration %= self.state.epoch_length
764+
self._init_iter.append(iteration)
759765

760766
def _internal_run(self) -> State:
761767
self.should_terminate = self.should_terminate_single_epoch = False
@@ -826,6 +832,11 @@ def _run_once_on_dataset(self) -> float:
826832
start_time = time.time()
827833

828834
# We need to setup iter_counter > 0 if we resume from an iteration
835+
if len(self._init_iter) > 1:
836+
raise RuntimeError(
837+
"Internal error, len(self._init_iter) should 0 or 1, "
838+
f"but got: {len(self._init_iter)}, {self._init_iter}"
839+
)
829840
iter_counter = self._init_iter.pop() if len(self._init_iter) > 0 else 0
830841
should_exit = False
831842
try:

tests/ignite/engine/test_engine.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def check_iter_and_data():
206206

207207
assert state.epoch == max_epochs
208208
assert not engine.should_terminate
209-
assert state.iteration == real_epoch_length * (max_epochs - 1)
209+
assert state.iteration == real_epoch_length * (max_epochs - 1) + (iteration_to_stop % real_epoch_length)
210210

211211

212212
class RecordedEngine(Engine):

0 commit comments

Comments
 (0)