Skip to content

Commit

Permalink
Added optional flag skip_epoch_completed to Engine.terminate_epoch()
Browse files Browse the repository at this point in the history
  • Loading branch information
bonassifabio committed Dec 4, 2024
1 parent 6f8ad2a commit adb2a01
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
36 changes: 25 additions & 11 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self.should_terminate = False
self.skip_completed_after_termination = False
self.should_terminate_single_epoch = False
self._skip_epoch_completed_after_termination = False
self.should_interrupt = False
self.state = State()
self._state_dict_user_keys: List[str] = []
Expand Down Expand Up @@ -628,7 +629,7 @@ def terminate():
self.should_terminate = True
self.skip_completed_after_termination = skip_completed

def terminate_epoch(self) -> None:
def terminate_epoch(self, skip_epoch_completed: bool = False) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
continues from the next epoch. The following events are triggered:
Expand All @@ -638,12 +639,17 @@ def terminate_epoch(self) -> None:
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
- :attr:`~ignite.engine.events.Events.EPOCH_STARTED`
- ...
Args:
skip_epoch_completed: if True, the event :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
is not fired after :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`. Default is False.
"""
self.logger.info(
"Terminate current epoch is signaled. "
"Current epoch iteration will stop after current iteration is finished."
)
self.should_terminate_single_epoch = True
self._skip_epoch_completed_after_termination = skip_epoch_completed

def _handle_exception(self, e: BaseException) -> None:
if Events.EXCEPTION_RAISED in self._event_handlers:
Expand Down Expand Up @@ -982,11 +988,15 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
if not self._skip_epoch_completed_after_termination:
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
else:
self._skip_epoch_completed_after_termination = False

yield from self._maybe_terminate_or_interrupt()

hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
Expand Down Expand Up @@ -1167,11 +1177,15 @@ def _internal_run_legacy(self) -> State:
# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
if not self._skip_epoch_completed_after_termination:
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
else:
self._skip_epoch_completed_after_termination = False

self._maybe_terminate_legacy()

hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
Expand Down
12 changes: 9 additions & 3 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ class Events(EventEnum):
- TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
:meth:`~ignite.engine.engine.Engine.terminate()` call.
- EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- EPOCH_COMPLETED : triggered when the epoch is ended. This is triggered even
when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called,
unless the flag `skip_epoch_completed` is set to True.
- TERMINATE : triggered when the run is about to end completely,
after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call.
Expand All @@ -272,7 +273,7 @@ class Events(EventEnum):
The table below illustrates which events are triggered when various termination methods are called.
.. list-table::
:widths: 35 38 28 20 20
:widths: 38 38 28 20 20
:header-rows: 1
* - Method
Expand All @@ -290,6 +291,11 @@ class Events(EventEnum):
- ✔
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate_epoch()` with `skip_epoch_completed=True`
- ✔
- ✗
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate()`
- ✗
- ✔
Expand Down
36 changes: 25 additions & 11 deletions tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,16 +292,19 @@ def assert_no_exceptions(ee):
assert engine.called_events[0] == (0, 0, Events.STARTED)
assert engine._dataloader_iter is None

@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length):
@pytest.mark.parametrize(
"data, epoch_length, skip_epoch_completed",
[(None, 10, False), (range(10), None, False), (None, 10, True), (range(10), None, True)],
)
def test_terminate_epoch_stops_mid_epoch(self, data, epoch_length, skip_epoch_completed):
real_epoch_length = epoch_length if data is None else len(data)
iteration_to_stop = real_epoch_length + 4

engine = Engine(MagicMock(return_value=1))

def start_of_iteration_handler(engine):
if engine.state.iteration == iteration_to_stop:
engine.terminate_epoch()
engine.terminate_epoch(skip_epoch_completed)

max_epochs = 3
engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler)
Expand All @@ -312,15 +315,19 @@ def start_of_iteration_handler(engine):
assert state.epoch == max_epochs

@pytest.mark.parametrize(
"terminate_epoch_event, i",
"terminate_epoch_event, i, skip_epoch_completed",
[
(Events.GET_BATCH_STARTED(once=12), 12),
(Events.GET_BATCH_COMPLETED(once=12), 12),
(Events.ITERATION_STARTED(once=14), 14),
(Events.ITERATION_COMPLETED(once=14), 14),
(Events.GET_BATCH_STARTED(once=12), 12, False),
(Events.GET_BATCH_COMPLETED(once=12), 12, False),
(Events.ITERATION_STARTED(once=14), 14, False),
(Events.ITERATION_COMPLETED(once=14), 14, False),
(Events.GET_BATCH_STARTED(once=12), 12, True),
(Events.GET_BATCH_COMPLETED(once=12), 12, True),
(Events.ITERATION_STARTED(once=14), 14, True),
(Events.ITERATION_COMPLETED(once=14), 14, True),
],
)
def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i):
def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i, skip_epoch_completed):
engine = RecordedEngine(MagicMock(return_value=1))
data = range(10)
max_epochs = 3
Expand All @@ -331,23 +338,27 @@ def test_terminate_epoch_events_sequence(self, terminate_epoch_event, i):

@engine.on(terminate_epoch_event)
def call_terminate_epoch():
assert not engine._skip_epoch_completed_after_termination
nonlocal call_count
if call_count < 1:
engine.terminate_epoch()
engine.terminate_epoch(skip_epoch_completed)
assert engine._skip_epoch_completed_after_termination == skip_epoch_completed

call_count += 1

@engine.on(Events.TERMINATE_SINGLE_EPOCH)
def check_previous_events(iter_counter):
e = i // len(data) + 1

assert engine.called_events[0] == (0, 0, Events.STARTED)
assert engine.called_events[-2] == (e, i, terminate_epoch_event)
assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
assert engine._skip_epoch_completed_after_termination == skip_epoch_completed

@engine.on(Events.EPOCH_COMPLETED)
def check_previous_events2():
e = i // len(data) + 1
if e == engine.state.epoch and i == engine.state.iteration:
assert not skip_epoch_completed
assert engine.called_events[-3] == (e, i, terminate_epoch_event)
assert engine.called_events[-2] == (e, i, Events.TERMINATE_SINGLE_EPOCH)
assert engine.called_events[-1] == (e, i, Events.EPOCH_COMPLETED)
Expand All @@ -357,6 +368,9 @@ def check_previous_events2():
assert engine.state.epoch == max_epochs
assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data)

epoch_completed_events = [e for e in engine.called_events if e[2] == Events.EPOCH_COMPLETED.name]
assert len(epoch_completed_events) == max_epochs - skip_epoch_completed

@pytest.mark.parametrize("data", [None, "mock_data_loader"])
def test_iteration_events_are_fired(self, data):
max_epochs = 5
Expand Down

0 comments on commit adb2a01

Please sign in to comment.