diff --git a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py index b7ff5cdf9829..4cdd0856e7b7 100644 --- a/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py +++ b/deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py @@ -37,11 +37,14 @@ def create(self, tag): self.checkpoint = torch_nebula.Checkpoint(tag, -2) def save(self, state_dict, path: str): + log_dist(f"[Nebula] Create dummy files for loading.") + torch.save("", path) + tag = _get_tag_from_path(path) partititon_name = os.path.basename(path) - logger.info(f"[Nebula] Saving {partititon_name} under tag{tag}...") + logger.info(f"[Nebula] Saving {partititon_name} under tag {tag}...") self.checkpoint.save(partititon_name, state_dict) - logger.info(f"[Nebula] Saved {partititon_name} under tag{tag}.") + logger.info(f"[Nebula] Saved {partititon_name} under tag {tag}.") return None def load(self, path: str, map_location=None): @@ -50,29 +53,51 @@ def load(self, path: str, map_location=None): if not self.enable_nebula_load and first_load_flag: self.tag_flag = tag logger.info( - f"[Nebula] Disable nebula load. Loading checkpoint from {path}...") + f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...") partition = torch.load(path, map_location=map_location) - logger.info( - f"[Nebula] Disable nebula load. Loaded checkpoint from {path}...") + logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .") return partition partititon_name = os.path.basename(path) logger.info( - f"[Nebula] Loading {path} under tag{tag} from {self.nebula_load_path}...") + f"[Nebula] Loading {path} under tag {tag} from nebula path {self.nebula_load_path}..." + ) checkpoint = None - if tag is None: + if tag in (None, 'latest', 'latest_universal'): + # In some cases, there is the inconsistent tag between deepspeed metadata (latest file) + # and nebula metadata, will lead to the failure on loading with deepspeed tag. Then we + # will try to load the valid latest checkpoint from nebula(tier3 > tier1). So, in summary + # when met failure loading for given tag, the loading priority would be like: + # nebula tier3 latest > nebula tier1 latest. checkpoint = torch_nebula.get_latest_checkpoint( persist_path=self.nebula_load_path) - if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''): - logger.warning(f"Unable to find latest valid checkpoint from Nebula!") - return None else: checkpoint = torch_nebula.get_checkpoint(tag=tag, persist_path=self.nebula_load_path) + + if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''): + logger.info( + f"Unable to find valid checkpoint tag:{tag} from Nebula, try to get latest checkpoint again from nebula {self.nebula_load_path} path!" + ) + # nebula tier3 latest + checkpoint = torch_nebula.get_latest_checkpoint( + persist_path=self.nebula_load_path) + if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''): + logger.info( + f"Unable to find latest checkpoint from Nebula tier3, try to get latest checkpoint again from nebula tier1 path!" + ) + # nebula tier1 latest + checkpoint = torch_nebula.get_latest_checkpoint() + logger.warning( + f"Unable to find valid checkpoint from Nebula under tag:{tag}.") + return None + + tag = checkpoint.tag + self.tag_flag = -1 partition = checkpoint.load(partititon_name, map_location=map_location) logger.info( - f"[Nebula] Loaded {path} under tag{tag} from {self.nebula_load_path}.") + f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.") return partition def commit(self, tag):