99 changes: 80 additions & 19 deletions deepspeed/runtime/engine.py
@@ -78,6 +78,9 @@
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
from deepspeed.utils.logging import print_json_dist

# SCR: import Scalable Checkpoint/Restart library
import scr

# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None

@@ -258,6 +261,17 @@ def __init__(
if not dist.is_initialized():
dist.init_process_group(backend=self.dist_backend)

# SCR: capture SCR config settings
self.use_scr = args.scr
self.scr_output_interval = args.save_interval

# Moved scr.init/finalize to megatron
# SCR: initialize SCR
# if self.use_scr:
# # DeepSpeed expects checkpoint files to be in a global file system on restart
# scr.config("SCR_GLOBAL_RESTART=1")
# scr.init()

self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
@@ -2570,23 +2584,32 @@ def load_checkpoint(self,
"""

if tag is None:
latest_tag = "latest_universal" if self.load_universal_checkpoint(
) else "latest"
latest_path = os.path.join(load_dir, latest_tag)
if os.path.isfile(latest_path):
with open(latest_path, "r") as fd:
tag = fd.read().strip()
if self.use_scr:
# SCR: get name of latest checkpoint from SCR
tag = scr.have_restart()
if tag is None:
if self.global_rank == 0:
logger.warning(f"SCR unable to find checkpoint")
return None, None
if self.global_rank == 0:
logger.info(f"SCR found dataset named '{tag}'")
else:
if self.load_universal_checkpoint():
raise ValueError(
f'Invalid for universal checkpoint: {latest_path} does not exist'
)
latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest"
latest_path = os.path.join(load_dir, latest_tag)
if os.path.isfile(latest_path):
with open(latest_path, "r") as fd:
tag = fd.read().strip()
else:
logger.warning(
f"Unable to find latest file at {latest_path}, if trying to load latest "
"checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
)
return None, None
if self.load_universal_checkpoint():
raise ValueError(
f'Invalid for universal checkpoint: {latest_path} does not exist'
)
else:
logger.warning(
f"Unable to find latest file at {latest_path}, if trying to load latest "
"checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
)
return None, None

if self.zero_optimization_partition_weights():
# Prepare for checkpoint load by ensuring all parameters are partitioned
@@ -2908,8 +2931,11 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
# This is to make sure the checkpoint names are created without collision
# There seems to be an issue creating them in parallel

# SCR: avoid creating directory since SCR will do that as needed on flush
# Ensure save_dir directory exists
os.makedirs(save_dir, exist_ok=True)
if not self.use_scr:
os.makedirs(save_dir, exist_ok=True)

dist.barrier()

if tag is None:
@@ -2922,6 +2948,16 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
# Ensure checkpoint tag is consistent across ranks
self._checkpoint_tag_validation(tag)

# SCR: start checkpoint, use tag as the dataset name
valid = True
if self.use_scr:
# Consider checkpoint to be defensive unless the global step count
# is divisible by the save_interval, then also mark as output to force flush
scr_flags = scr.FLAG_CHECKPOINT
if self.global_steps % self.scr_output_interval == 0:
scr_flags |= scr.FLAG_OUTPUT
scr.start_output(tag, scr_flags)

if self.has_moe_layers:
self.save_non_zero_checkpoint = False
self._create_checkpoint_file(save_dir, tag, False)
@@ -2943,13 +2979,22 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
self.optimizer.checkpoint_event_epilogue()

# Save latest checkpoint tag
# SCR: Avoid writing the latest file when using SCR.
# It can't be written as an SCR file,
# since each checkpoint writes to this same file.
# Instead, SCR returns the tag value during restart via scr.have_restart().
self.checkpoint_engine.commit(tag)
if save_latest and rank == 0:
if save_latest and rank == 0 and not self.use_scr:
with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag)

dist.barrier()

# TODO: Set valid=False if calling rank failed to write any of its checkpoint files.
# SCR: complete checkpoint
if self.use_scr:
scr.complete_output(valid)

return True

def _get_non_moe_state_dict(self, full_state_dict):
@@ -3073,6 +3118,10 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
self._curr_save_path = None

def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
# SCR: skip creating directory since SCR will create it as needed during flush
if self.use_scr:
return True

name_function = (self._get_zero_ckpt_name
if zero_checkpoint else self._get_ckpt_name)
try:
@@ -3129,6 +3178,11 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):

if self.save_non_zero_checkpoint:
log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])

# SCR: register checkpoint with SCR, and get path to open file from SCR
if self.use_scr:
save_path = scr.route_file(save_path)

self.checkpoint_engine.save(state, save_path)

def _get_buffer_names(self):
@@ -3209,10 +3263,17 @@ def _save_zero_checkpoint(self, save_path, tag):
zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(),
ds_config=self.config,
ds_version=version)

# SCR: register checkpoint with SCR, and get path to open file from SCR
if self.use_scr:
zero_checkpoint_name = scr.route_file(zero_checkpoint_name)

self.checkpoint_engine.save(zero_sd, zero_checkpoint_name)

if self.global_rank == 0:
self._copy_recovery_script(save_path)
# TODO: fixme
# SCR: disable copying of recovery script
#if self.global_rank == 0:
# self._copy_recovery_script(save_path)
ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero'
logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}')

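Taken together, the engine.py changes bracket DeepSpeed's existing checkpoint I/O with SCR's output cycle and replace the latest-file lookup with SCR's restart query. The following is a minimal sketch of that cycle, using only the SCR calls introduced in this diff (scr.start_output, scr.route_file, scr.complete_output, scr.have_restart); the file name, state dictionary, and helper names are illustrative placeholders rather than DeepSpeed's actual checkpoint layout.

```python
import os
import torch
import scr


def save_step(save_dir, tag, state, step, output_interval):
    """Write one checkpoint dataset through SCR (illustrative sketch)."""
    flags = scr.FLAG_CHECKPOINT
    if step % output_interval == 0:
        # Periodically mark the dataset as output to force a flush to the
        # global file system, mirroring the save_interval logic above.
        flags |= scr.FLAG_OUTPUT
    scr.start_output(tag, flags)        # open an SCR dataset named after the tag
    path = scr.route_file(os.path.join(save_dir, tag, "model_states.pt"))
    torch.save(state, path)             # SCR may have redirected path to node-local storage
    scr.complete_output(True)           # valid=True: this rank wrote its files successfully


def restart_tag():
    """Ask SCR for the most recent checkpoint instead of reading a 'latest' file."""
    return scr.have_restart()           # None when SCR has nothing to restart from
```

In the patch itself, the same sequence is split across save_checkpoint/_save_checkpoint (start_output, route_file, complete_output) and load_checkpoint (have_restart), with the DeepSpeed checkpoint tag doubling as the SCR dataset name.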
2 changes: 2 additions & 0 deletions deepspeed/runtime/pipe/engine.py
@@ -184,6 +184,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):

self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline

self.module.use_scr = self.use_scr

if self.is_last_stage():
self.loss_model = self.module.loss_fn

13 changes: 12 additions & 1 deletion deepspeed/runtime/pipe/module.py
@@ -15,6 +15,9 @@
from .topology import PipeDataParallelTopology, PipelineParallelGrid
from deepspeed.runtime.state_dict_factory import SDLoaderFactory

# SCR: import Scalable Checkpoint/Restart library
import scr


class PipelineError(Exception):
"""Errors related to the use of deepspeed.PipelineModule """
@@ -581,7 +584,10 @@ def save_state_dict(self, save_dir, checkpoint_engine):
start, end = 0, num_layers
layer_list = self.forward_funcs[start:end]

os.makedirs(save_dir, exist_ok=True)
# SCR: skip makedirs since SCR will create them as needed during the flush
if not self.use_scr:
os.makedirs(save_dir, exist_ok=True)

for idx, layer in enumerate(layer_list):
model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
if not hasattr(layer, 'state_dict'):
@@ -597,6 +603,11 @@ def save_state_dict(self, save_dir, checkpoint_engine):
{k: v.clone()
for k,
v in orig_state_dict.items()})

# SCR: register checkpoint file and get path to write file from SCR
if self.use_scr:
model_ckpt_path = scr.route_file(model_ckpt_path)

checkpoint_engine.save(final_state_dict, model_ckpt_path)

def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
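The commented-out block near the top of engine.py notes that scr.init()/scr.finalize() were moved into the Megatron driver. Below is a minimal sketch of that driver-side bracket, assuming the same SCR_GLOBAL_RESTART setting as the commented-out code; run_training is a hypothetical placeholder for the actual training loop.

```python
import scr


def run_training():
    # Hypothetical stand-in for the Megatron/DeepSpeed training loop, which
    # calls model_engine.save_checkpoint(...) and load_checkpoint(...) as usual.
    pass


def main():
    # DeepSpeed expects checkpoint files to be in a global file system on restart.
    scr.config("SCR_GLOBAL_RESTART=1")
    scr.init()          # call after the distributed environment is initialized
    try:
        run_training()
    finally:
        scr.finalize()  # flush any cached checkpoint datasets before exit


if __name__ == "__main__":
    main()
```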