28
28
from typing import DefaultDict , Dict , List , Literal , Optional , Tuple , Type , cast
29
29
30
30
import torch
31
+ import viser
31
32
from rich import box , style
32
33
from rich .panel import Panel
33
34
from rich .table import Table
@@ -137,6 +138,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =
137
138
138
139
self .viewer_state = None
139
140
141
+ # used to keep track of the current step
142
+ self .step = 0
143
+
140
144
def setup (self , test_mode : Literal ["test" , "val" , "inference" ] = "val" ) -> None :
141
145
"""Setup the Trainer by calling other setup functions.
142
146
@@ -233,8 +237,15 @@ def train(self) -> None:
233
237
with TimeWriter (writer , EventName .TOTAL_TRAIN_TIME ):
234
238
num_iterations = self .config .max_num_iterations
235
239
step = 0
240
+ self .stop_training = False
236
241
for step in range (self ._start_step , self ._start_step + num_iterations ):
242
+ self .step = step
243
+ if self .stop_training :
244
+ break
237
245
while self .training_state == "paused" :
246
+ if self .stop_training :
247
+ self ._after_train ()
248
+ return
238
249
time .sleep (0.01 )
239
250
with self .train_lock :
240
251
with TimeWriter (writer , EventName .ITER_TRAIN_TIME , step = step ) as train_t :
@@ -291,12 +302,26 @@ def train(self) -> None:
291
302
292
303
writer .write_out_storage ()
293
304
305
+ # save checkpoint at the end of training, and write out any remaining events
306
+ self ._after_train ()
307
+
308
+ def shutdown (self ) -> None :
309
+ """Stop the trainer and stop all associated threads/processes (such as the viewer)."""
310
+ self .stop_training = True # tell the training loop to stop
311
+ if self .viewer_state is not None :
312
+ # stop the viewer
313
+ # this condition excludes the case where `viser_server` is either `None` or an
314
+ # instance of `viewer_legacy`'s `ViserServer` instead of the upstream one.
315
+ if isinstance (self .viewer_state .viser_server , viser .ViserServer ):
316
+ self .viewer_state .viser_server .stop ()
317
+
318
+ def _after_train (self ) -> None :
319
+ """Function to run after training is complete"""
320
+ self .training_state = "completed" # used to update the webui state
294
321
# save checkpoint at the end of training
295
- self .save_checkpoint (step )
296
-
322
+ self .save_checkpoint (self .step )
297
323
# write out any remaining events (e.g., total train time)
298
324
writer .write_out_storage ()
299
-
300
325
table = Table (
301
326
title = None ,
302
327
show_header = False ,
@@ -309,7 +334,7 @@ def train(self) -> None:
309
334
310
335
# after train end callbacks
311
336
for callback in self .callbacks :
312
- callback .run_callback_at_location (step = step , location = TrainingCallbackLocation .AFTER_TRAIN )
337
+ callback .run_callback_at_location (step = self . step , location = TrainingCallbackLocation .AFTER_TRAIN )
313
338
314
339
if not self .config .viewer .quit_on_train_completion :
315
340
self ._train_complete_viewer ()
0 commit comments