Skip to content

Commit

Permalink
Ensure we save a checkpoint at the end of the training loop
Browse files (browse the repository at this point in the history)
  • Loading branch information
epwalsh committed Jan 3, 2024
1 parent 568a3d8 commit 6da42cf
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions olmo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,11 +949,16 @@ def on_trace_ready(p):
else:
log.info("Training loop complete")

# Save final unsharded model-only checkpoint.
if not canceled and self.cfg.save_interval_unsharded is not None:
log.info("Saving final unsharded model checkpoint...")
checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
# Save final checkpoint.
if not canceled:
if self.cfg.save_interval_unsharded is not None:
log.info("Saving final unsharded model checkpoint...")
checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded)
log.info(f"Unsharded checkpoint saved to {checkpoint_path}")
elif self.cfg.save_num_checkpoints_to_keep != 0:
log.info("Saving final checkpoint...")
checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded)
log.info(f"Checkpoint saved to {checkpoint_path}")

def close(self, exit_code: int = 0) -> None:
if self.indices_file is not None:
Expand Down

0 comments on commit 6da42cf

Please sign in to comment.