diff --git a/olmo/train.py b/olmo/train.py index af5c99710..33306480a 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -949,11 +949,16 @@ def on_trace_ready(p): else: log.info("Training loop complete") - # Save final unsharded model-only checkpoint. - if not canceled and self.cfg.save_interval_unsharded is not None: - log.info("Saving final unsharded model checkpoint...") - checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded) - log.info(f"Unsharded checkpoint saved to {checkpoint_path}") + # Save final checkpoint. + if not canceled: + if self.cfg.save_interval_unsharded is not None: + log.info("Saving final unsharded model checkpoint...") + checkpoint_path, _ = self.save_checkpoint(CheckpointType.unsharded) + log.info(f"Unsharded checkpoint saved to {checkpoint_path}") + elif self.cfg.save_num_checkpoints_to_keep != 0: + log.info("Saving final checkpoint...") + checkpoint_path, _ = self.save_checkpoint(CheckpointType.sharded) + log.info(f"Checkpoint saved to {checkpoint_path}") def close(self, exit_code: int = 0) -> None: if self.indices_file is not None: