diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 5b3295e99fc8..f965eb688d16 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1352,7 +1352,7 @@ def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank): zero_ckpt_name = os.path.join( checkpoints_path, str(tag), - filename + '_mp_rank_{:02d}'.format(mp_rank) + 'optim_states.pt') + filename + '_mp_rank_{:02d}'.format(mp_rank) + '_optim_states.pt') return zero_ckpt_name def _get_zero_ckpt_name(self, checkpoints_path, tag): @@ -1529,13 +1529,20 @@ def _get_all_zero_checkpoints(self, load_dir, tag): mp_rank=mp_rank, dp_world_size=self.loaded_checkpoint_dp_world_size) invalid_zero_ckpt_paths = [] - for ckpt_name in zero_ckpt_names: + for i, ckpt_name in enumerate(zero_ckpt_names): if not os.path.exists(ckpt_name): + # transparently handle the old file pattern for optim_states + if 'optim_states.pt' in ckpt_name: + ckpt_name_try = ckpt_name.replace("_optim_states.pt", + "optim_states.pt") + if os.path.exists(ckpt_name_try): + zero_ckpt_names[i] = ckpt_name_try + continue invalid_zero_ckpt_paths.append(ckpt_name) if len(invalid_zero_ckpt_paths) > 0: logger.warn( - f"Client provided zero checkpoint load paths: {invalid_zero_ckpt_paths} does not exist" + f"The following zero checkpoints paths are missing: {invalid_zero_ckpt_paths}" ) return None