fix: canceled vs cancel_initiated
epwalsh committed Nov 6, 2023
1 parent b828938 commit a1c32e9
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions olmo/train.py
@@ -772,9 +772,10 @@ def on_trace_ready(p):
 
         # Train.
         first_batch: bool = True
+        cancel_initiated: bool = False
         canceled: bool = False
-        hard_stop: bool = False
         stop_at: Optional[int] = self.cfg.stop_at
+        save_checkpoints: bool = True
 
         with torch_profiler as p:
             for batch in self.train_loader:
@@ -828,20 +829,24 @@ def on_trace_ready(p):
                     wandb.log(metrics, step=self.global_step)
 
                 # Check if/when run should be canceled.
-                if not canceled and self.global_step % self.cfg.canceled_check_interval == 0:
-                    canceled, extra_steps = self.check_if_cancelled()
+                if not cancel_initiated and self.global_step % self.cfg.canceled_check_interval == 0:
+                    cancel_initiated, extra_steps = self.check_if_cancelled()
                     stop_at = (
                         self.global_step + extra_steps
                         if stop_at is None
                         else min(self.global_step + extra_steps, stop_at)
                     )
 
                 if stop_at is not None and self.global_step >= stop_at:
-                    canceled = hard_stop = True
+                    canceled = True
 
                 # Maybe save sharded checkpoint.
-                if canceled or (
-                    self.global_step % self.cfg.save_interval == 0 and self.cfg.save_num_checkpoints_to_keep != 0
+                if save_checkpoints and (
+                    cancel_initiated
+                    or (
+                        self.global_step % self.cfg.save_interval == 0
+                        and self.cfg.save_num_checkpoints_to_keep != 0
+                    )
                 ):
                     log.info("Saving checkpoint...")
                     checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
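For context, the scheduling logic in the hunk above can be read as a small standalone function. The sketch below is not the Trainer's code: it assumes `check_if_cancelled()` returns a `(cancel_initiated, extra_steps)` pair, which is how the diff unpacks it, and unlike the hunk (which runs the update unconditionally after the check) it only moves `stop_at` when a cancellation was actually requested.

```python
# Standalone sketch of the stop-scheduling logic above; not the Trainer itself.
# Assumption: check_if_cancelled() returns (cancel_initiated, extra_steps),
# matching how the diff unpacks its result.
from typing import Callable, Optional, Tuple


def update_stop_at(
    global_step: int,
    stop_at: Optional[int],
    check_if_cancelled: Callable[[], Tuple[bool, int]],
) -> Tuple[bool, Optional[int]]:
    cancel_initiated, extra_steps = check_if_cancelled()
    if cancel_initiated:
        # Allow a few more steps so a final checkpoint can be written,
        # but never push past a stop_at that is already scheduled.
        deadline = global_step + extra_steps
        stop_at = deadline if stop_at is None else min(deadline, stop_at)
    return cancel_initiated, stop_at


# Cancellation detected at step 100 with 10 grace steps, while cfg.stop_at
# was already 105: the run stops at step 105, not 110.
print(update_stop_at(100, 105, lambda: (True, 10)))  # -> (True, 105)
```

The point of `extra_steps` is to give the run a short grace period so a final checkpoint can be written before the job is actually killed.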
@@ -850,9 +855,13 @@ def on_trace_ready(p):
                     # Reset speed monitor so that we don't count the time taken to save checkpoints.
                     speed_monitor.reset()
 
+                    # If the run was just canceled this will be the final checkpoint.
+                    if cancel_initiated:
+                        save_checkpoints = False
+
                 # Maybe save unsharded checkpoint.
                 if (
-                    not canceled  # we already save a sharded checkpoint when canceled
+                    save_checkpoints
                     and self.cfg.save_interval_unsharded is not None
                     and self.global_step % self.cfg.save_interval_unsharded == 0
                     and self.cfg.save_num_unsharded_checkpoints_to_keep != 0
@@ -865,7 +874,7 @@ def on_trace_ready(p):
                     speed_monitor.reset()
 
                 # Maybe run evaluations.
-                if not canceled and self.global_step % self.cfg.eval_interval == 0:
+                if not cancel_initiated and self.global_step % self.cfg.eval_interval == 0:
                     eval_metrics = self.eval()
 
                     # Log metrics to W&B.
@@ -883,10 +892,7 @@ def on_trace_ready(p):
                 if p is not None:
                     p.step()
 
-                if hard_stop:
-                    # End training loop due to cancellation or reaching `cfg.stop_at`.
-                    # NOTE: no need to check `canceled` here since we always set `stop_at` when the run
-                    # is canceled and then set `hard_stop` based on `stop_at`.
+                if canceled:
                     break
 
                 # Python Profiler stuff
@@ -902,7 +908,7 @@ def on_trace_ready(p):
         log.info("Training loop complete")
 
         # Save final unsharded model-only checkpoint.
-        if not canceled and self.cfg.save_interval_unsharded is not None:
+        if save_checkpoints and self.cfg.save_interval_unsharded is not None:
             log.info("Saving final unsharded model checkpoint...")
             checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
             log.info(f"Unsharded checkpoint saved to {checkpoint_path}")