Skip to content
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
25e04b3
Merge pull request #1 from microsoft/master
trajepl Jun 23, 2022
bbd2bde
Merge pull request #2 from microsoft/master
trajepl Jul 21, 2022
9dcfb93
Merge pull request #3 from microsoft/master
trajepl Aug 1, 2022
5e74d56
generate empty files for nebula load
trajepl Aug 1, 2022
76e01e1
Merge branch 'microsoft:master' into users/jiapli/load_fix
trajepl Aug 4, 2022
5d41478
Merge branch 'master' into users/jiapli/load_fix
tjruwase Aug 4, 2022
abf04ba
latest tag fix
trajepl Aug 5, 2022
c76c423
Merge branch 'users/jiapli/load_fix' of github.com:trajepl/DeepSpeed …
trajepl Aug 5, 2022
a0ec298
update condition to 'in'
trajepl Aug 8, 2022
2f288af
add some export envs for more platform
Aug 9, 2022
252725a
Merge branch 'master' into users/jiapli/load_fix
tjruwase Aug 9, 2022
67a0293
Merge branch 'master' into users/jiapli/load_fix
tjruwase Aug 9, 2022
a86f824
Merge branch 'master' into users/jiapli/load_fix
tjruwase Aug 9, 2022
e2be404
Merge branch 'master' into users/jiapli/load_fix
tjruwase Aug 9, 2022
97e46bd
fix: diff on ds_latest and nebula latest
trajepl Aug 11, 2022
5c37b4c
Merge branch 'users/jiapli/load_fix' of github.com:trajepl/DeepSpeed …
trajepl Aug 11, 2022
52daf7a
Merge branch 'master' into users/jiapli/load_fix
trajepl Aug 11, 2022
9caf6ca
Merge branch 'master' into users/jiapli/load_fix
trajepl Aug 12, 2022
8b3bd59
Merge branch 'master' into users/jiapli/load_fix
trajepl Aug 16, 2022
cbc0cda
comments for loading back
trajepl Aug 16, 2022
10f4868
Merge branch 'master' into users/jiapli/load_fix
trajepl Aug 16, 2022
b015988
Merge branch 'master' into users/jiapli/load_fix
trajepl Aug 17, 2022
7763468
Merge branch 'master' into users/jiapli/load_fix
trajepl Aug 18, 2022
3b04b7c
revert to export more env vars
Aug 18, 2022
fc15dbd
revert to export more env vars
Aug 18, 2022
c685a81
Merge branch 'master' into users/jiapli/load_fix
trajepl Aug 22, 2022
647ebf5
Merge branch 'master' into users/jiapli/load_fix
tjruwase Aug 22, 2022
14a6cde
Merge branch 'master' into users/jiapli/load_fix
tjruwase Aug 22, 2022
17a0ba1
Merge branch 'master' into users/jiapli/load_fix
trajepl Aug 24, 2022
01caf9f
Merge branch 'master' into users/jiapli/load_fix
trajepl Aug 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ def create(self, tag):
self.checkpoint = torch_nebula.Checkpoint(tag, -2)

def save(self, state_dict, path: str):
log_dist(f"[Nebula] Create dummy files for loading.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Main change is here.

torch.save("", path)

tag = _get_tag_from_path(path)
partititon_name = os.path.basename(path)
logger.info(f"[Nebula] Saving {partititon_name} under tag{tag}...")
logger.info(f"[Nebula] Saving {partititon_name} under tag {tag}...")
self.checkpoint.save(partititon_name, state_dict)
logger.info(f"[Nebula] Saved {partititon_name} under tag{tag}.")
logger.info(f"[Nebula] Saved {partititon_name} under tag {tag}.")
return None

def load(self, path: str, map_location=None):
Expand All @@ -50,29 +53,51 @@ def load(self, path: str, map_location=None):
if not self.enable_nebula_load and first_load_flag:
self.tag_flag = tag
logger.info(
f"[Nebula] Disable nebula load. Loading checkpoint from {path}...")
f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
partition = torch.load(path, map_location=map_location)
logger.info(
f"[Nebula] Disable nebula load. Loaded checkpoint from {path}...")
logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
return partition

partititon_name = os.path.basename(path)
logger.info(
f"[Nebula] Loading {path} under tag{tag} from {self.nebula_load_path}...")
f"[Nebula] Loading {path} under tag {tag} from nebula path {self.nebula_load_path}..."
)

checkpoint = None
if tag is None:
if tag in (None, 'latest', 'latest_universal'):
# In some cases the tag recorded in the deepspeed metadata (latest file) is inconsistent
# with the nebula metadata, which causes loading with the deepspeed tag to fail. In that
# case we try to load the latest valid checkpoint from nebula (tier3 > tier1). In summary,
# when loading with the given tag fails, the loading priority is:
# nebula tier3 latest > nebula tier1 latest.
checkpoint = torch_nebula.get_latest_checkpoint(
persist_path=self.nebula_load_path)
if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
logger.warning(f"Unable to find latest valid checkpoint from Nebula!")
return None
else:
checkpoint = torch_nebula.get_checkpoint(tag=tag,
persist_path=self.nebula_load_path)

if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
logger.info(
f"Unable to find valid checkpoint tag:{tag} from Nebula, try to get latest checkpoint again from nebula {self.nebula_load_path} path!"
)
# nebula tier3 latest
checkpoint = torch_nebula.get_latest_checkpoint(
persist_path=self.nebula_load_path)
if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
logger.info(
f"Unable to find latest checkpoint from Nebula tier3, try to get latest checkpoint again from nebula tier1 path!"
)
# nebula tier1 latest
checkpoint = torch_nebula.get_latest_checkpoint()
logger.warning(
f"Unable to find valid checkpoint from Nebula under tag:{tag}.")
return None

tag = checkpoint.tag
self.tag_flag = -1
partition = checkpoint.load(partititon_name, map_location=map_location)
logger.info(
f"[Nebula] Loaded {path} under tag{tag} from {self.nebula_load_path}.")
f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.")
return partition

def commit(self, tag):
Expand Down