Skip to content

Commit

Permalink
Add option to start training paused (#3420)
Browse files (browse the repository at this point in the history)
  • Loading branch information
aayushg55 authored Sep 10, 2024
1 parent e2a0915 commit e7b7dc9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
8 changes: 6 additions & 2 deletions nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class TrainerConfig(ExperimentConfig):
"""Optionally log gradients during training"""
gradient_accumulation_steps: Dict[str, int] = field(default_factory=lambda: {})
"""Number of steps to accumulate gradients over. Contains a mapping of {param_group:num}"""
start_paused: bool = False
"""Whether to start the training in a paused state."""


class Trainer:
Expand Down Expand Up @@ -121,7 +123,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =
self.device += f":{local_rank}"
self.mixed_precision: bool = self.config.mixed_precision
self.use_grad_scaler: bool = self.mixed_precision or self.config.use_grad_scaler
self.training_state: Literal["training", "paused", "completed"] = "training"
self.training_state: Literal["training", "paused", "completed"] = (
"paused" if self.config.start_paused else "training"
)
self.gradient_accumulation_steps: DefaultDict = defaultdict(lambda: 1)
self.gradient_accumulation_steps.update(self.config.gradient_accumulation_steps)

Expand Down Expand Up @@ -361,7 +365,7 @@ def _init_viewer_state(self) -> None:
assert self.viewer_state and self.pipeline.datamanager.train_dataset
self.viewer_state.init_scene(
train_dataset=self.pipeline.datamanager.train_dataset,
train_state="training",
train_state=self.training_state,
eval_dataset=self.pipeline.datamanager.eval_dataset,
)

Expand Down
12 changes: 9 additions & 3 deletions nerfstudio/viewer/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ def __init__(
self.output_type_changed = True
self.output_split_type_changed = True
self.step = 0
self.train_btn_state: Literal["training", "paused", "completed"] = "training"
self._prev_train_state: Literal["training", "paused", "completed"] = "training"
self.train_btn_state: Literal["training", "paused", "completed"] = (
"training" if self.trainer is None else self.trainer.training_state
)
self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state
self.last_move_time = 0
# track the index of the camera that was last clicked
self.current_camera_idx = 0
Expand Down Expand Up @@ -174,7 +176,11 @@ def __init__(
)
self.resume_train.on_click(lambda _: self.toggle_pause_button())
self.resume_train.on_click(lambda han: self._toggle_training_state(han))
self.resume_train.visible = False
if self.train_btn_state == "training":
self.resume_train.visible = False
else:
self.pause_train.visible = False

# Add buttons to toggle training image visibility
self.hide_images = self.viser_server.gui.add_button(
label="Hide Train Cams", disabled=False, icon=viser.Icon.EYE_OFF, color=None
Expand Down
6 changes: 4 additions & 2 deletions nerfstudio/viewer_legacy/server/viewer_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ def __init__(
self.output_type_changed = True
self.output_split_type_changed = True
self.step = 0
self.train_btn_state: Literal["training", "paused", "completed"] = "training"
self._prev_train_state: Literal["training", "paused", "completed"] = "training"
self.train_btn_state: Literal["training", "paused", "completed"] = (
"training" if self.trainer is None else self.trainer.training_state
)
self._prev_train_state: Literal["training", "paused", "completed"] = self.train_btn_state

self.camera_message = None

Expand Down

0 comments on commit e7b7dc9

Please sign in to comment.