Skip to content

Commit cca9beb

Browse files
committed
- Merged flags "should_terminate" and "_skip_completed_after_termination".
- Merged flags "should_terminate_single_epoch" and "_skip_epoch_completed_after_termination".
1 parent e2980de commit cca9beb

File tree

2 files changed

+34
-24
lines changed

2 files changed

+34
-24
lines changed

Diff for: ignite/engine/engine.py

+14-18
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,12 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
139139
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
140140
self._process_function = process_function
141141
self.last_event_name: Optional[Events] = None
142-
self.should_terminate = False
143-
self._skip_completed_after_termination = False
144-
self.should_terminate_single_epoch = False
145-
self._skip_epoch_completed_after_termination = False
142+
# should_terminate flag: False - don't terminate, True - terminate,
143+
# "skip_completed" - terminate and skip the event "COMPLETED"
144+
self.should_terminate: bool | str = False
145+
# should_terminate_single_epoch flag: False - don't terminate, True - terminate,
146+
# "skip_epoch_completed" - terminate and skip the event "EPOCH_COMPLETED"
147+
self.should_terminate_single_epoch: bool | str = False
146148
self.should_interrupt = False
147149
self.state = State()
148150
self._state_dict_user_keys: List[str] = []
@@ -626,8 +628,7 @@ def terminate():
626628
Added `skip_completed` flag
627629
"""
628630
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
629-
self.should_terminate = True
630-
self._skip_completed_after_termination = skip_completed
631+
self.should_terminate = "skip_completed" if skip_completed else True
631632

632633
def terminate_epoch(self, skip_epoch_completed: bool = False) -> None:
633634
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
@@ -651,8 +652,7 @@ def terminate_epoch(self, skip_epoch_completed: bool = False) -> None:
651652
"Terminate current epoch is signaled. "
652653
"Current epoch iteration will stop after current iteration is finished."
653654
)
654-
self.should_terminate_single_epoch = True
655-
self._skip_epoch_completed_after_termination = skip_epoch_completed
655+
self.should_terminate_single_epoch = "skip_epoch_completed" if skip_epoch_completed else True
656656

657657
def _handle_exception(self, e: BaseException) -> None:
658658
if Events.EXCEPTION_RAISED in self._event_handlers:
@@ -991,15 +991,14 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
991991
# time is available for handlers but must be updated after fire
992992
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
993993

994-
if not self._skip_epoch_completed_after_termination:
994+
if self.should_terminate_single_epoch != "skip_epoch_completed":
995995
handlers_start_time = time.time()
996996
self._fire_event(Events.EPOCH_COMPLETED)
997997
epoch_time_taken += time.time() - handlers_start_time
998998
# update time wrt handlers
999999
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
1000-
else:
1001-
self._skip_epoch_completed_after_termination = False
10021000

1001+
self.should_terminate_single_epoch = False
10031002
yield from self._maybe_terminate_or_interrupt()
10041003

10051004
hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
@@ -1015,7 +1014,7 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
10151014
self.state.times[Events.COMPLETED.name] = time_taken
10161015

10171016
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1018-
if not (self.should_terminate and self._skip_completed_after_termination):
1017+
if self.should_terminate != "skip_completed":
10191018
handlers_start_time = time.time()
10201019
self._fire_event(Events.COMPLETED)
10211020
time_taken += time.time() - handlers_start_time
@@ -1134,7 +1133,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11341133

11351134
except _EngineTerminateSingleEpochException:
11361135
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
1137-
self.should_terminate_single_epoch = False
11381136
self._setup_dataloader_iter()
11391137

11401138
except _EngineTerminateException as e:
@@ -1180,15 +1178,14 @@ def _internal_run_legacy(self) -> State:
11801178
# time is available for handlers but must be updated after fire
11811179
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
11821180

1183-
if not self._skip_epoch_completed_after_termination:
1181+
if self.should_terminate_single_epoch != "skip_epoch_completed":
11841182
handlers_start_time = time.time()
11851183
self._fire_event(Events.EPOCH_COMPLETED)
11861184
epoch_time_taken += time.time() - handlers_start_time
11871185
# update time wrt handlers
11881186
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
1189-
else:
1190-
self._skip_epoch_completed_after_termination = False
11911187

1188+
self.should_terminate_single_epoch = False
11921189
self._maybe_terminate_legacy()
11931190

11941191
hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
@@ -1204,7 +1201,7 @@ def _internal_run_legacy(self) -> State:
12041201
self.state.times[Events.COMPLETED.name] = time_taken
12051202

12061203
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1207-
if not (self.should_terminate and self._skip_completed_after_termination):
1204+
if self.should_terminate != "skip_completed":
12081205
handlers_start_time = time.time()
12091206
self._fire_event(Events.COMPLETED)
12101207
time_taken += time.time() - handlers_start_time
@@ -1309,7 +1306,6 @@ def _run_once_on_dataset_legacy(self) -> float:
13091306

13101307
except _EngineTerminateSingleEpochException:
13111308
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
1312-
self.should_terminate_single_epoch = False
13131309
self._setup_dataloader_iter()
13141310

13151311
except _EngineTerminateException as e:

Diff for: tests/ignite/engine/test_engine.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@ def set_interrupt_resume_enabled(self, interrupt_resume_enabled):
4444
def test_terminate(self, skip_completed):
4545
engine = Engine(lambda e, b: 1)
4646
assert not engine.should_terminate
47-
assert not engine._skip_completed_after_termination
47+
4848
engine.terminate(skip_completed)
49-
assert engine.should_terminate
50-
assert engine._skip_completed_after_termination == skip_completed
49+
50+
if skip_completed:
51+
assert engine.should_terminate == "skip_completed"
52+
else:
53+
assert engine.should_terminate == True # noqa: E712
5154

5255
def test_invalid_process_raises_with_invalid_signature(self):
5356
with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"):
@@ -338,27 +341,38 @@ def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i, skip_ep
338341

339342
@engine.on(terminate_epoch_event)
340343
def call_terminate_epoch():
341-
assert not engine._skip_epoch_completed_after_termination
344+
assert not engine.should_terminate_single_epoch
342345
nonlocal call_count
343346
if call_count < 1:
344347
engine.terminate_epoch(skip_epoch_completed)
345-
assert engine._skip_epoch_completed_after_termination == skip_epoch_completed
348+
if skip_epoch_completed:
349+
assert engine.should_terminate_single_epoch == "skip_epoch_completed"
350+
else:
351+
assert engine.should_terminate_single_epoch == True # noqa: E712
346352

347353
call_count += 1
348354

355+
@engine.on(Events.EPOCH_STARTED)
356+
def check_skip_reset():
357+
assert engine.should_terminate_single_epoch == False # noqa: E712
358+
349359
@engine.on(Events.TERMINATE_SINGLE_EPOCH)
350360
def check_previous_events(iter_counter):
351361
e = i // len(data) + 1
352362
assert engine.called_events[0] == (0, 0, Events.STARTED)
353363
assert engine.called_events[-2] == (e, i, terminate_epoch_event)
354364
assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
355-
assert engine._skip_epoch_completed_after_termination == skip_epoch_completed
365+
if skip_epoch_completed:
366+
assert engine.should_terminate_single_epoch == "skip_epoch_completed"
367+
else:
368+
assert engine.should_terminate_single_epoch == True # noqa: E712
356369

357370
@engine.on(Events.EPOCH_COMPLETED)
358371
def check_previous_events2():
359372
e = i // len(data) + 1
360373
if e == engine.state.epoch and i == engine.state.iteration:
361374
assert not skip_epoch_completed
375+
assert isinstance(engine.should_terminate_single_epoch, bool)
362376
assert engine.called_events[-3] == (e, i, terminate_epoch_event)
363377
assert engine.called_events[-2] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
364378
assert engine.called_events[-1] == (e, i, Events.EPOCH_COMPLETED)

0 commit comments

Comments
 (0)