diff --git a/nerfstudio/engine/trainer.py b/nerfstudio/engine/trainer.py index ad9f5100b5..644d62a5b5 100644 --- a/nerfstudio/engine/trainer.py +++ b/nerfstudio/engine/trainer.py @@ -86,6 +86,8 @@ class TrainerConfig(ExperimentConfig): """Optionally log gradients during training""" gradient_accumulation_steps: Dict[str, int] = field(default_factory=lambda: {}) """Number of steps to accumulate gradients over. Contains a mapping of {param_group:num}""" + start_paused: bool = False + """Whether to start the training in a paused state.""" class Trainer: @@ -121,7 +123,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int = self.device += f":{local_rank}" self.mixed_precision: bool = self.config.mixed_precision self.use_grad_scaler: bool = self.mixed_precision or self.config.use_grad_scaler - self.training_state: Literal["training", "paused", "completed"] = "training" + self.training_state: Literal["training", "paused", "completed"] = ( + "paused" if self.config.start_paused else "training" + ) self.gradient_accumulation_steps: DefaultDict = defaultdict(lambda: 1) self.gradient_accumulation_steps.update(self.config.gradient_accumulation_steps) @@ -361,7 +365,7 @@ def _init_viewer_state(self) -> None: assert self.viewer_state and self.pipeline.datamanager.train_dataset self.viewer_state.init_scene( train_dataset=self.pipeline.datamanager.train_dataset, - train_state="training", + train_state=self.training_state, eval_dataset=self.pipeline.datamanager.eval_dataset, ) diff --git a/nerfstudio/viewer/viewer.py b/nerfstudio/viewer/viewer.py index 0d23ff151f..bc58043aa6 100644 --- a/nerfstudio/viewer/viewer.py +++ b/nerfstudio/viewer/viewer.py @@ -103,8 +103,10 @@ def __init__( self.output_type_changed = True self.output_split_type_changed = True self.step = 0 - self.train_btn_state: Literal["training", "paused", "completed"] = "training" - self._prev_train_state: Literal["training", "paused", "completed"] = "training" + self.train_btn_state: Literal["training", "paused", "completed"] = ( + "training" if self.trainer is None else self.trainer.training_state + ) + self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state self.last_move_time = 0 # track the camera index that last being clicked self.current_camera_idx = 0 @@ -174,7 +176,11 @@ def __init__( ) self.resume_train.on_click(lambda _: self.toggle_pause_button()) self.resume_train.on_click(lambda han: self._toggle_training_state(han)) - self.resume_train.visible = False + if self.train_btn_state == "training": + self.resume_train.visible = False + else: + self.pause_train.visible = False + # Add buttons to toggle training image visibility self.hide_images = self.viser_server.gui.add_button( label="Hide Train Cams", disabled=False, icon=viser.Icon.EYE_OFF, color=None diff --git a/nerfstudio/viewer_legacy/server/viewer_state.py b/nerfstudio/viewer_legacy/server/viewer_state.py index cfb3bff7b1..990a7bbf1d 100644 --- a/nerfstudio/viewer_legacy/server/viewer_state.py +++ b/nerfstudio/viewer_legacy/server/viewer_state.py @@ -116,8 +116,10 @@ def __init__( self.output_type_changed = True self.output_split_type_changed = True self.step = 0 - self.train_btn_state: Literal["training", "paused", "completed"] = "training" - self._prev_train_state: Literal["training", "paused", "completed"] = "training" + self.train_btn_state: Literal["training", "paused", "completed"] = ( + "training" if self.trainer is None else self.trainer.training_state + ) + self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state self.camera_message = None