From 905359e60bcfb22f1a2d5e7a1c3628ec470f0e59 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 11 Jan 2024 08:09:25 -0800 Subject: [PATCH] fix bug with saving unsharded checkpoint --- olmo/checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index 1e2aba816..58bd4f736 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -494,6 +494,8 @@ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]: checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp") if get_fs_local_rank() == 0: shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True) + checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True) + barrier() # Yield temporary directory for `.save_checkpoint()` to use. @@ -502,10 +504,8 @@ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]: barrier() # Finally if all went well replace the temporary directory with the actual - # checkpoint directory. Note that for some checkpointers the local rank 0 might - # not use this folder, so it may not exist; FullCheckpointer, for example, only creates - # this for global rank 0. - if get_fs_local_rank() == 0 and checkpoint_dir_tmp.exists(): + # checkpoint directory. + if get_fs_local_rank() == 0: # Replace temp directory with target checkpoint directory. try: checkpoint_dir_tmp.replace(checkpoint_dir)