@@ -139,10 +139,12 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
139
139
self .logger = logging .getLogger (__name__ + "." + self .__class__ .__name__ )
140
140
self ._process_function = process_function
141
141
self .last_event_name : Optional [Events ] = None
142
- self .should_terminate = False
143
- self ._skip_completed_after_termination = False
144
- self .should_terminate_single_epoch = False
145
- self ._skip_epoch_completed_after_termination = False
142
+ # should_terminate flag: False - don't terminate, True - terminate,
143
+ # "skip_completed" - terminate and skip the event "COMPLETED"
144
+ self .should_terminate : bool | str = False
145
+ # should_terminate_single_epoch flag: False - don't terminate, True - terminate,
146
+ # "skip_epoch_completed" - terminate and skip the event "EPOCH_COMPLETED"
147
+ self .should_terminate_single_epoch : bool | str = False
146
148
self .should_interrupt = False
147
149
self .state = State ()
148
150
self ._state_dict_user_keys : List [str ] = []
@@ -626,8 +628,7 @@ def terminate():
626
628
Added `skip_completed` flag
627
629
"""
628
630
self .logger .info ("Terminate signaled. Engine will stop after current iteration is finished." )
629
- self .should_terminate = True
630
- self ._skip_completed_after_termination = skip_completed
631
+ self .should_terminate = "skip_completed" if skip_completed else True
631
632
632
633
def terminate_epoch (self , skip_epoch_completed : bool = False ) -> None :
633
634
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
@@ -651,8 +652,7 @@ def terminate_epoch(self, skip_epoch_completed: bool = False) -> None:
651
652
"Terminate current epoch is signaled. "
652
653
"Current epoch iteration will stop after current iteration is finished."
653
654
)
654
- self .should_terminate_single_epoch = True
655
- self ._skip_epoch_completed_after_termination = skip_epoch_completed
655
+ self .should_terminate_single_epoch = "skip_epoch_completed" if skip_epoch_completed else True
656
656
657
657
def _handle_exception (self , e : BaseException ) -> None :
658
658
if Events .EXCEPTION_RAISED in self ._event_handlers :
@@ -991,15 +991,14 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
991
991
# time is available for handlers but must be updated after fire
992
992
self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
993
993
994
- if not self ._skip_epoch_completed_after_termination :
994
+ if self .should_terminate_single_epoch != "skip_epoch_completed" :
995
995
handlers_start_time = time .time ()
996
996
self ._fire_event (Events .EPOCH_COMPLETED )
997
997
epoch_time_taken += time .time () - handlers_start_time
998
998
# update time wrt handlers
999
999
self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
1000
- else :
1001
- self ._skip_epoch_completed_after_termination = False
1002
1000
1001
+ self .should_terminate_single_epoch = False
1003
1002
yield from self ._maybe_terminate_or_interrupt ()
1004
1003
1005
1004
hours , mins , secs = _to_hours_mins_secs (epoch_time_taken )
@@ -1015,7 +1014,7 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
1015
1014
self .state .times [Events .COMPLETED .name ] = time_taken
1016
1015
1017
1016
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1018
- if not ( self .should_terminate and self . _skip_completed_after_termination ) :
1017
+ if self .should_terminate != "skip_completed" :
1019
1018
handlers_start_time = time .time ()
1020
1019
self ._fire_event (Events .COMPLETED )
1021
1020
time_taken += time .time () - handlers_start_time
@@ -1134,7 +1133,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
1134
1133
1135
1134
except _EngineTerminateSingleEpochException :
1136
1135
self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
1137
- self .should_terminate_single_epoch = False
1138
1136
self ._setup_dataloader_iter ()
1139
1137
1140
1138
except _EngineTerminateException as e :
@@ -1180,15 +1178,14 @@ def _internal_run_legacy(self) -> State:
1180
1178
# time is available for handlers but must be updated after fire
1181
1179
self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
1182
1180
1183
- if not self ._skip_epoch_completed_after_termination :
1181
+ if self .should_terminate_single_epoch != "skip_epoch_completed" :
1184
1182
handlers_start_time = time .time ()
1185
1183
self ._fire_event (Events .EPOCH_COMPLETED )
1186
1184
epoch_time_taken += time .time () - handlers_start_time
1187
1185
# update time wrt handlers
1188
1186
self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
1189
- else :
1190
- self ._skip_epoch_completed_after_termination = False
1191
1187
1188
+ self .should_terminate_single_epoch = False
1192
1189
self ._maybe_terminate_legacy ()
1193
1190
1194
1191
hours , mins , secs = _to_hours_mins_secs (epoch_time_taken )
@@ -1204,7 +1201,7 @@ def _internal_run_legacy(self) -> State:
1204
1201
self .state .times [Events .COMPLETED .name ] = time_taken
1205
1202
1206
1203
# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
1207
- if not ( self .should_terminate and self . _skip_completed_after_termination ) :
1204
+ if self .should_terminate != "skip_completed" :
1208
1205
handlers_start_time = time .time ()
1209
1206
self ._fire_event (Events .COMPLETED )
1210
1207
time_taken += time .time () - handlers_start_time
@@ -1309,7 +1306,6 @@ def _run_once_on_dataset_legacy(self) -> float:
1309
1306
1310
1307
except _EngineTerminateSingleEpochException :
1311
1308
self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
1312
- self .should_terminate_single_epoch = False
1313
1309
self ._setup_dataloader_iter ()
1314
1310
1315
1311
except _EngineTerminateException as e :
0 commit comments