Skip to content

Commit eddf2d2

Browse files
KevinXu02 and brentyi authored
Changes for trainer.py to support the Gradio webui (nerfstudio-project#3046)
* changes for trainer to support webui * Update trainer to support webui * format * add a separate shutdown() function to stop training * typo fix * get rid of _stop_viewer_server() * Update trainer.py * organize import --------- Co-authored-by: Brent Yi <[email protected]>
1 parent 45d8bb7 commit eddf2d2

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

nerfstudio/engine/trainer.py

+29-4
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
from typing import DefaultDict, Dict, List, Literal, Optional, Tuple, Type, cast
2929

3030
import torch
31+
import viser
3132
from rich import box, style
3233
from rich.panel import Panel
3334
from rich.table import Table
@@ -137,6 +138,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =
137138

138139
self.viewer_state = None
139140

141+
# used to keep track of the current step
142+
self.step = 0
143+
140144
def setup(self, test_mode: Literal["test", "val", "inference"] = "val") -> None:
141145
"""Setup the Trainer by calling other setup functions.
142146
@@ -233,8 +237,15 @@ def train(self) -> None:
233237
with TimeWriter(writer, EventName.TOTAL_TRAIN_TIME):
234238
num_iterations = self.config.max_num_iterations
235239
step = 0
240+
self.stop_training = False
236241
for step in range(self._start_step, self._start_step + num_iterations):
242+
self.step = step
243+
if self.stop_training:
244+
break
237245
while self.training_state == "paused":
246+
if self.stop_training:
247+
self._after_train()
248+
return
238249
time.sleep(0.01)
239250
with self.train_lock:
240251
with TimeWriter(writer, EventName.ITER_TRAIN_TIME, step=step) as train_t:
@@ -291,12 +302,26 @@ def train(self) -> None:
291302

292303
writer.write_out_storage()
293304

305+
# save checkpoint at the end of training, and write out any remaining events
306+
self._after_train()
307+
308+
def shutdown(self) -> None:
309+
"""Request that training stop and stop the viewer server, if one is running.

Sets `stop_training`, which the loop in `train()` checks at the start of each
step and while training is paused. This method only sets the flag; it does not
itself wait for the training loop to exit.
"""
310+
self.stop_training = True # flag polled by the loop in train(); it exits at its next check
311+
if self.viewer_state is not None:
312+
# Stop the viewer, but only when its server is the upstream `viser.ViserServer`.
313+
# The isinstance check below excludes the case where `viser_server` is either `None` or an
314+
# instance of `viewer_legacy`'s `ViserServer` instead of the upstream one.
315+
if isinstance(self.viewer_state.viser_server, viser.ViserServer):
316+
self.viewer_state.viser_server.stop()
317+
318+
def _after_train(self) -> None:
319+
"""Function to run after training is complete"""
320+
self.training_state = "completed" # used to update the webui state
294321
# save checkpoint at the end of training
295-
self.save_checkpoint(step)
296-
322+
self.save_checkpoint(self.step)
297323
# write out any remaining events (e.g., total train time)
298324
writer.write_out_storage()
299-
300325
table = Table(
301326
title=None,
302327
show_header=False,
@@ -309,7 +334,7 @@ def train(self) -> None:
309334

310335
# after train end callbacks
311336
for callback in self.callbacks:
312-
callback.run_callback_at_location(step=step, location=TrainingCallbackLocation.AFTER_TRAIN)
337+
callback.run_callback_at_location(step=self.step, location=TrainingCallbackLocation.AFTER_TRAIN)
313338

314339
if not self.config.viewer.quit_on_train_completion:
315340
self._train_complete_viewer()

0 commit comments

Comments
 (0)