Merged

Commits (42)
25e04b3
Merge pull request #1 from microsoft/master
trajepl Jun 23, 2022
d88e591
enable checkpoint engine
trajepl Jun 30, 2022
07e59d6
separated nebula config
trajepl Jul 11, 2022
4cbdfe6
add __init__.py for nebula importing
trajepl Jul 11, 2022
1f2f40c
linter fix
trajepl Jul 11, 2022
d900145
fix: ds_config is None
trajepl Jul 11, 2022
b44832b
fix: ds config
trajepl Jul 11, 2022
e4a57bd
fix: get sd loader fix
trajepl Jul 11, 2022
85e52f5
Merge branch 'master' into trajepl/nebula_ckpt_engine
tjruwase Jul 11, 2022
d70bcd1
align the API with torch raw code
trajepl Jul 12, 2022
4c50308
Merge branch 'trajepl/nebula_ckpt_engine' of github.com:trajepl/DeepS…
trajepl Jul 12, 2022
5d987a0
linter fix
trajepl Jul 12, 2022
a04a81a
remove duplicate tag params
trajepl Jul 12, 2022
21b70bd
Merge branch 'master' into trajepl/nebula_ckpt_engine
mrwyattii Jul 18, 2022
81ccd07
Merge branch 'master' into trajepl/nebula_ckpt_engine
tjruwase Jul 19, 2022
4b42bc2
make checkpoint_engine as required args
trajepl Jul 21, 2022
22f8c2a
Merge branch 'trajepl/nebula_ckpt_engine' of github.com:trajepl/DeepS…
trajepl Jul 21, 2022
bbd2bde
Merge pull request #2 from microsoft/master
trajepl Jul 21, 2022
d9298cf
Merge branch 'master' of github.com:trajepl/DeepSpeed into trajepl/ne…
trajepl Jul 21, 2022
1906398
fix args
trajepl Jul 21, 2022
432e7c6
extract parameters out to config
trajepl Jul 21, 2022
7dbb6d8
fix: load state dict
trajepl Jul 21, 2022
e912e31
separate load engine
trajepl Jul 21, 2022
7fc279b
linter fix
trajepl Jul 22, 2022
5ebacc6
Merge branch 'master' into trajepl/nebula_ckpt_engine
tjruwase Jul 22, 2022
c70c818
extract checkpoint engine to abstract class
trajepl Jul 22, 2022
e6dd794
linter fix
trajepl Jul 22, 2022
3788ada
Merge branch 'trajepl/nebula_ckpt_engine' of github.com:trajepl/DeepS…
trajepl Jul 22, 2022
1efd2ce
construct function args fix
trajepl Jul 22, 2022
dce0fb5
add docs for dev/customers
trajepl Jul 25, 2022
bb5bb7c
linter fix
trajepl Jul 25, 2022
0c21dc2
Merge branch 'master' into trajepl/nebula_ckpt_engine
tjruwase Jul 25, 2022
3e8c238
remove load engine
trajepl Jul 26, 2022
a5c8897
print->log_dist
trajepl Jul 26, 2022
44d687b
linter fix
trajepl Jul 26, 2022
82ad297
add tag flag to distinguish the loading order
trajepl Jul 26, 2022
cf12a8d
Merge branch 'master' into trajepl/nebula_ckpt_engine
tjruwase Jul 26, 2022
422221b
Merge branch 'master' into trajepl/nebula_ckpt_engine
tjruwase Jul 27, 2022
340de11
Merge branch 'master' into trajepl/nebula_ckpt_engine
tjruwase Jul 27, 2022
7f3f14c
Merge branch 'master' into trajepl/nebula_ckpt_engine
tjruwase Jul 27, 2022
5071091
Merge branch 'master' into trajepl/nebula_ckpt_engine
jeffra Jul 27, 2022
1b43df5
Merge branch 'master' into trajepl/nebula_ckpt_engine
tjruwase Jul 27, 2022
11 changes: 8 additions & 3 deletions deepspeed/inference/engine.py
@@ -9,6 +9,7 @@

 from torch.nn.modules import Module
 from packaging import version as pkg_version
+from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine

 from ..runtime.state_dict_factory import SDLoaderFactory
 from ..runtime.weight_quantizer import WeightQuantization
@@ -92,6 +93,7 @@ def __init__(self,
         self.expert_mp_group = expert_mp_group
         self.enable_cuda_graph = enable_cuda_graph
         self.cuda_graph_created = False
+        self.checkpoint_engine = TorchCheckpointEngine()
         self._init_quantization_setting(quantization_setting)

         if enable_cuda_graph:
@@ -376,9 +378,10 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
                     tag = fd.read().strip()

             ckpt_list = self._get_all_ckpt_names(load_dir, tag)
-            sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list)
+            sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
         else:
-            sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
+            sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
+                                                           self.checkpoint_engine)

         if type(sd_loader) is list:
             self.sd = torch.load(sd_loader[0], map_location='cpu')
@@ -416,10 +419,12 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
                 state_dict=checkpoint[self._choose_module_key(checkpoint)],
                 old_moe_load=old_moe_load,
                 model=self.module,
-                mpu=self.mpu)
+                mpu=self.mpu,
+                checkpoint_engine=self.checkpoint_engine)

             self.module.load_state_dict(
                 state_dict=checkpoint[self._choose_module_key(checkpoint)],
+                checkpoint_engine=self.checkpoint_engine,
                 strict=load_module_strict)

     def _choose_module_key(self, sd):
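The `TorchCheckpointEngine` imported above is not shown in this diff; the sketch below is only an illustration of what a torch-backed engine implementing the `CheckpointEngine` interface might look like (delegating to `torch.save`/`torch.load`, with `create`/`commit` as logging no-ops), not the actual file added by the PR.

```python
# Illustrative sketch only -- the real torch_checkpoint_engine.py is not part of the diff shown here.
import torch
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine
from deepspeed.utils import logger


class TorchCheckpointEngine(CheckpointEngine):
    def __init__(self, config_params=None):
        super().__init__(config_params)

    def create(self, tag):
        # Nothing to prepare for plain torch; just log the tag.
        logger.info(f"[Torch] Checkpoint {tag} is about to be saved!")

    def save(self, state_dict, path: str):
        # Plain torch persists the file as soon as save() returns.
        torch.save(state_dict, path)

    def load(self, path: str, map_location=None):
        return torch.load(path, map_location=map_location)

    def commit(self, tag):
        # Persistence is implicit for torch, so commit is only a signal/log.
        logger.info(f"[Torch] Checkpoint {tag} is ready now!")
        return True
```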
12 changes: 12 additions & 0 deletions deepspeed/launcher/launch.py
@@ -19,6 +19,7 @@
 from argparse import ArgumentParser, REMAINDER

 from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
+from ..nebula.constants import DLTS_POD_ENV_PATH
 from ..utils import logger

 PID_FILE_BASEPATH = "/tmp"
@@ -143,6 +144,17 @@ def main():
         with open(pid_file, 'w') as fd:
             fd.write(f"{launcher_pid}")

+    if os.path.exists(DLTS_POD_ENV_PATH):
+        with open(DLTS_POD_ENV_PATH) as file:
+            lines = file.readlines()
+            lines = [line.rstrip() for line in lines]
+            for line in lines:
+                if line.startswith('export FC_TASKROLE_NAME') or line.startswith(
+                        'export FC_TASK_INDEX'):
+                    key_val = line.split()[1]
+                    key, val = key_val.split('=')
+                    current_env[key] = val
+
     processes = []
     cmd = []
     for local_rank in range(0, num_local_procs):
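For context, the added block simply promotes selected `export` lines from the ITP pod env file into the launcher's environment. A minimal sketch of that parsing, using a hypothetical pod.env line (the value `worker` is made up):

```python
# Minimal sketch of the parsing above; the real file lives at DLTS_POD_ENV_PATH
# and its contents depend on the ITP cluster.
current_env = {}
line = "export FC_TASKROLE_NAME=worker"   # hypothetical pod.env line

if line.startswith('export FC_TASKROLE_NAME') or line.startswith('export FC_TASK_INDEX'):
    key_val = line.split()[1]        # -> "FC_TASKROLE_NAME=worker"
    key, val = key_val.split('=')    # -> ("FC_TASKROLE_NAME", "worker")
    current_env[key] = val           # forwarded into each launched rank's environment
```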
4 changes: 3 additions & 1 deletion deepspeed/launcher/runner.py
@@ -20,12 +20,14 @@
 from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner
 from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER
 from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
+from ..nebula.constants import NEBULA_EXPORT_ENVS
 from ..utils import logger

 from ..autotuning import Autotuner

 DLTS_HOSTFILE = "/job/hostfile"
-EXPORT_ENVS = ["NCCL", "PYTHON", "MV2", "UCX"]
+EXPORT_ENVS = ['NCCL', 'PYTHON', 'MV2', 'UCX']
+EXPORT_ENVS += NEBULA_EXPORT_ENVS
 DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
 DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.']
 PDSH_MAX_FAN_OUT = 1024
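A rough illustration of how the extended list would typically be used, assuming the runner treats `EXPORT_ENVS` entries as prefixes when deciding which environment variables to forward to worker nodes (the actual matching code lives elsewhere in runner.py and is not part of this diff):

```python
# Hedged sketch: forwarding environment variables whose names start with an
# EXPORT_ENVS entry. NEBULA_EXPORT_ENVS below is a small illustrative subset.
import os

EXPORT_ENVS = ['NCCL', 'PYTHON', 'MV2', 'UCX']
NEBULA_EXPORT_ENVS = ['DLTS_JOB_ID', 'NEBULA_PERSISTENT_STORAGE_PATH']
EXPORT_ENVS += NEBULA_EXPORT_ENVS

exports = {
    var: val
    for var, val in os.environ.items()
    if any(var.startswith(name) for name in EXPORT_ENVS)
}
# 'exports' is what a multi-node launcher would propagate to every node in the job.
```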
Empty file added deepspeed/nebula/__init__.py
53 changes: 53 additions & 0 deletions deepspeed/nebula/config.py
@@ -0,0 +1,53 @@
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""

from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
from deepspeed.nebula.constants import *


class DeepSpeedNebulaConfig(DeepSpeedConfigObject):
    def __init__(self, param_dict):
        super(DeepSpeedNebulaConfig, self).__init__()

        self.enabled = None
        self.persistent_storage_path = None
        self.persistent_time_interval = None
        self.num_of_version_in_retention = None
        self.enable_nebula_load = None

        if NEBULA in param_dict.keys():
            nebula_dict = param_dict[NEBULA]
        else:
            nebula_dict = {}

        self._initialize(nebula_dict)

    def _initialize(self, nebula_dict):
        self.enabled = get_scalar_param(nebula_dict,
                                        NEBULA_ENABLED,
                                        NEBULA_ENABLED_DEFAULT)

        self.load_path = get_scalar_param(nebula_dict,
                                          NEBULA_LOAD_PATH,
                                          NEBULA_LOAD_PATH_DEFAULT)

        self.enable_nebula_load = get_scalar_param(nebula_dict,
                                                   NEBULA_ENABLE_NEBULA_LOAD,
                                                   NEBULA_ENABLE_NEBULA_LOAD_DEFAULT)

        self.persistent_storage_path = get_scalar_param(
            nebula_dict,
            NEBULA_PERSISTENT_STORAGE_PATH,
            NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT)

        self.persistent_time_interval = get_scalar_param(
            nebula_dict,
            NEBULA_PERSISTENT_TIME_INTERVAL,
            NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT)

        self.num_of_version_in_retention = get_scalar_param(
            nebula_dict,
            NEBULA_NUM_OF_VERSION_IN_RETENTION,
            NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT)
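A small usage sketch for this config object; the dictionary mirrors the `NEBULA_FORMAT` example from constants.py, and where the `nebula` block is nested inside a real ds_config is not shown in this diff:

```python
from deepspeed.nebula.config import DeepSpeedNebulaConfig

# Hypothetical param_dict containing a "nebula" section.
param_dict = {
    "nebula": {
        "enabled": True,
        "persistent_storage_path": "/foo/bar",
        "persistent_time_interval": 100,
        "num_of_version_in_retention": 2,
        "enable_nebula_load": True,
    }
}

nebula_config = DeepSpeedNebulaConfig(param_dict)
print(nebula_config.enabled)                      # True
print(nebula_config.persistent_storage_path)      # /foo/bar
print(nebula_config.num_of_version_in_retention)  # 2
```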
86 changes: 86 additions & 0 deletions deepspeed/nebula/constants.py
@@ -0,0 +1,86 @@
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""

#########################################
# nebula
#########################################
# Nebula. By default, this feature is not enabled.
# Users can configure it in ds_config.json as in the example below:
NEBULA_FORMAT = '''
nebula should be enabled as:
"session_params": {
  "nebula": {
    "enabled": true,
    "persistent_storage_path": "/foo/bar",
    "persistent_time_interval": 100,
    "num_of_version_in_retention": 2,
    "enable_nebula_load": true
  }
}
'''

NEBULA = "nebula"

NEBULA_ENABLED = "enabled"
NEBULA_ENABLED_DEFAULT = False

# There is a case where the customer wants to load a checkpoint saved
# by raw torch. Nebula cannot load a torch checkpoint directly, because
# the two use different folder structures, which creates a gap for
# loading (the saved data are exactly the same in bytes for torch and
# nebula).
# In this case, we must disable nebula load and use raw torch load.
# The customer can just set NEBULA_ENABLE_NEBULA_LOAD to False, and then
# use the original deepspeed way to load, i.e. set the value of "--load".
NEBULA_ENABLE_NEBULA_LOAD = "enable_nebula_load"
NEBULA_ENABLE_NEBULA_LOAD_DEFAULT = True

# When you want to resume the previous checkpoint saved by nebula,
# you can set NEBULA_LOAD_PATH as the parent folder of checkpoint.
# If NEBULA_LOAD_PATH is None, the NEBULA_PERSISTENT_STORAGE_PATH
# will be the default path to load.
NEBULA_LOAD_PATH = "nebula_load_path"
NEBULA_LOAD_PATH_DEFAULT = None

# Nebula will save the checkpoint under NEBULA_LOAD_PATH in an
# asynchronous way.
NEBULA_PERSISTENT_STORAGE_PATH = "persistent_storage_path"
NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT = None

# Time interval to trigger the nebula persistence.
NEBULA_PERSISTENT_TIME_INTERVAL = "persistent_time_interval"
NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT = 100

# Number of checkpoint versions to keep in memory. For example, if the
# value is 2 and checkpoints 1 and 2 are currently held, then when
# checkpoint 3 arrives, checkpoint 1 will be removed once it has been
# persisted to disk.
NEBULA_NUM_OF_VERSION_IN_RETENTION = "num_of_version_in_retention"
NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2

# Nebula envs
NEBULA_EXPORT_ENVS = [
'DLTS_JOB_ID',
'DLTS_NUM_WORKER',
'NEBULA_PERSISTENT_STORAGE_PATH',
'NEBULA_PERSISTENT_TIME_INTERVAL',
'AML_RUN_ID',
'AZUREML_RUN_TOKEN',
'AZUREML_WORKSPACE_SCOPE',
'AZUREML_EXPERIMENT_SCOPE',
'AZUREML_RUN_HISTORY_SERVICE_ENDPOINT',
'AZUREML_RUN_ID',
'NEBULA_MEMORY_BUFFER_SIZE',
'AZUREML_PARAMETER_ITPJOB_NAME',
'FC_TASKROLE_NAME',
'FC_TASK_INDEX',
'MASTER_HOST',
'LOCAL_HOST',
'AZUREML_BLOB_ACCOUNT_NAME',
'AZUREML_BLOB_ACCOUNT_KEY'
]

# ITP env files
DLTS_POD_ENV_PATH = '/dlts-runtime/env/pod.env'
37 changes: 37 additions & 0 deletions deepspeed/runtime/checkpoint_engine/README.md
@@ -0,0 +1,37 @@
# Checkpoint Engine


The `CheckpointEngine` is designed to modularize checkpoint serialization. In this way, we can simply replace or refine the checkpoint serialization methods.

### Interface for `CheckpointEngine`

Basically, for checkpoint management (save/load by deepspeed with a given tag), the `CheckpointEngine` will:

1. Make preliminaries ready by calling `create(tag)`. For `torch`, we can just log some extra info, as `torch` can directly call `save/load` without any other preparation.

2. After `create(tag)`, deepspeed can call `save/load` to persist files to disk/memory/etc.

3. When all the files for a tag are ready, the deepspeed engine will call `commit()` to tell the checkpoint engine that the current checkpoint is complete. For original torch, it also plays the role of a logger.


```python
class CheckpointEngine(object):
    # init checkpoint engine for save/load
    def __init__(self, config_params=None):
        pass

    def create(self, tag):
        # create checkpoint on the given tag for save/load.
        pass

    def save(self, state_dict, path: str):
        pass

    def load(self, path: str, map_location=None):
        pass

    def commit(self, tag):
        # to tell checkpoint services whether all files are ready.
        pass
```
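As a rough illustration of the three-step flow above (the tag, directory, and state dictionaries below are placeholders, and `TorchCheckpointEngine` is assumed to follow the interface exactly):

```python
# Sketch of the save-side call sequence; real orchestration happens inside
# the deepspeed engine's save_checkpoint path.
import os
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine

engine = TorchCheckpointEngine()
tag = "global_step100"                        # placeholder tag
save_dir = f"/tmp/checkpoints/{tag}"          # placeholder directory
os.makedirs(save_dir, exist_ok=True)
states = {"model_states.pt": {"weights": [1, 2, 3]}, "optim_states.pt": {"step": 100}}

engine.create(tag)                                        # 1. make preliminaries ready
for name, state in states.items():
    engine.save(state, os.path.join(save_dir, name))      # 2. persist each file for this tag
engine.commit(tag)                                        # 3. signal the checkpoint is complete
```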
Empty file added deepspeed/runtime/checkpoint_engine/__init__.py
19 changes: 19 additions & 0 deletions deepspeed/runtime/checkpoint_engine/checkpoint_engine.py
@@ -0,0 +1,19 @@
class CheckpointEngine(object):

    # init checkpoint engine for save/load
    def __init__(self, config_params=None):
        pass

    def create(self, tag):
        # create checkpoint on the given tag for save/load.
        pass

    def save(self, state_dict, path: str):
        pass

    def load(self, path: str, map_location=None):
        pass

    def commit(self, tag):
Contributor:

Is this a query function? What does the return value mean?

Contributor Author:

It tells the main process that the current checkpoint has been saved completely. For raw torch, it would be the same as the following code.

        if save_latest and self.global_rank == 0:
            with open(os.path.join(save_dir, 'latest'), 'w') as fd:
                fd.write(tag)

Nebula needs this signal to start the persistence processing for the checkpoint.

Contributor @tjruwase · Jul 20, 2022:

This snippet of code is independent of torch.save() as it is not writing checkpoint state to disk but rather writing a metadata file for managing checkpoints. So, I am not sure how it is similar to commit, which signals the start of persistence processing.

However, this two-step persistence process of nebula is different from torch, since the checkpoint is persisted by the time torch.save() returns.

With ZeRO training, multiple checkpoint files are created for a single deepspeed.save_checkpoint() using multiple torch.save() calls with different file paths. For nebula, do we need a single commit for each checkpoint file or just one for the save_checkpoint() call?

Some documentation on this commit and two-step persistence behavior would be very helpful. We need to think about whether this should be exposed to the client script or managed completely by DeepSpeed and hidden from clients.

Also, what does the return value mean?

Contributor Author @trajepl · Jul 20, 2022:

[image]
Actually, for nebula.save, we have the following steps to persist a checkpoint to disk:

  1. tier1: we pin shared memory and write the checkpoint files into shared memory (say, under /dev/shm/). This is a synchronous step that blocks training.
  2. tier3: when the nebula service receives the commit signal from the SDK, it starts a persistence process that moves the checkpoints from tier1 to tier3 (equal to nebula_load_path_tier3 where you commented). Once tier3 is done, the nebula meta info is written to the tier3 path.

So for your questions:

> For nebula, do we need a single commit for each checkpoint file or just one for the save_checkpoint() call?

Nebula needs every training process to call commit to tell the service that the current checkpoint can be moved to tier3. But original torch only writes the meta file for checkpoint management.

> what does the return value mean?

Hmm, it is a dummy return for torch, I think. If needed, maybe we can move the torch meta-file operation into this method, only for rank 0? What do you think? :)

Contributor Author:

One thing I want to call out: the file structures are totally different between the original torch path and the nebula tier3 path.

Contributor:

It is fine for the checkpoint file structure to be different between torch and nebula tier3; it simply means that torch.load() cannot be used for nebula checkpoints. This is not a concern for deepspeed.

I think the fd.write(tag) should remain inside deepspeed and outside nebula for a number of reasons: (1) it is not currently created by torch, since it is not checkpoint state but metadata, (2) it is used by deepspeed and existing clients for checkpoint management, and (3) it is very small data and fast to write. Typically, it is used to identify the latest checkpoint file in a folder of many checkpoint files, but clients can also use tags in other ways. So, can you explain why this part has to change for nebula?

Contributor Author:

It makes sense to keep fd.write(tag) out of nebula. Actually, we did that in the code as well.

# Save latest checkpoint tag
dist.barrier()
self.checkpoint_engine.commit(tag)
if save_latest and self.global_rank == 0:
    with open(os.path.join(save_dir, 'latest'), 'w') as fd:
        fd.write(tag)

Then for:

  1. torch: commit is only there to align the API, so no extra conditional (if nebula ... else ...) is needed.
  2. nebula: commit sends the signal to the nebula service to start tier3.

The fd.write(tag) works for torch, but it cannot replace commit() to tell the nebula service that the current checkpoint is saved.

So for your question:

> why does this part have to change for nebula?

We need commit() for nebula when all files under a tag are saved. For torch it just returns a dummy true.

Contributor:

> It makes sense to keep fd.write(tag) out of nebula. Actually, we did that in the code as well.

Sorry, I misread the PR. You are correct that you did not change this code.

Contributor:

> Then for:
>
>   1. torch: commit is only there to align the API, so no extra conditional (if nebula ... else ...) is needed.
>   2. nebula: commit sends the signal to the nebula service to start tier3.
>
> The fd.write(tag) works for torch, but it cannot replace commit() to tell the nebula service that the current checkpoint is saved.
>
> So for your question:
>
> why does this part have to change for nebula?
> We need commit() for nebula when all files under a tag are saved. For torch it just returns a dummy true.

So, this brings up the importance of clarifying the semantics of nebula and other checkpoint engines. I don't know if there is standard terminology for these concepts. But it seems to me that one key question is whether persistence is implicit or explicit. Persistence with torch.save() is implicit because checkpoint files are persisted on disk on return, while nebula is explicit because an additional call is required to make it persist the files. I think this is an important consideration for designing the general CheckpointEngine abstraction. We can discuss this further offline.

        # to tell checkpoint services whether all files are ready.
        pass
88 changes: 88 additions & 0 deletions deepspeed/runtime/checkpoint_engine/nebula_checkpoint_engine.py
@@ -0,0 +1,88 @@
import os
import torch
import torch_nebula

from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
    CheckpointEngine
from deepspeed.utils import logger, log_dist
from deepspeed.nebula.constants import *


def _get_tag_from_path(path):
    return os.path.basename(os.path.dirname(path))


class NebulaCheckpointEngine(CheckpointEngine):
    def __init__(self, config_params=None):
        super().__init__(config_params)
        self.checkpoint = None
        self.tag_flag = None
        self.enable_nebula_load = config_params.enable_nebula_load
        self.nebula_load_path = config_params.load_path
        if self.nebula_load_path is None:
            self.nebula_load_path = config_params.persistent_storage_path

        nebula_config_params = {
            NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path,
            NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_time_interval,
            NEBULA_NUM_OF_VERSION_IN_RETENTION:
            config_params.num_of_version_in_retention,
        }
        torch_nebula.init(**nebula_config_params)

    def create(self, tag):
        log_dist(f"[Nebula] Start Checkpoint for tag:{tag}", ranks=[0])
        # -2 means: the customer needs to explicitly tell nebula that the
        # current checkpoint is complete via the commit method.
        self.checkpoint = torch_nebula.Checkpoint(tag, -2)

    def save(self, state_dict, path: str):
        tag = _get_tag_from_path(path)
        partititon_name = os.path.basename(path)
        logger.info(f"[Nebula] Saving {partititon_name} under tag {tag}...")
        self.checkpoint.save(partititon_name, state_dict)
        logger.info(f"[Nebula] Saved {partititon_name} under tag {tag}.")
        return None

    def load(self, path: str, map_location=None):
        tag = _get_tag_from_path(path)
        first_load_flag = self.tag_flag is None or self.tag_flag == tag
        if not self.enable_nebula_load and first_load_flag:
            self.tag_flag = tag
            logger.info(
                f"[Nebula] Nebula load is disabled. Loading checkpoint from {path}...")
            partition = torch.load(path, map_location=map_location)
            logger.info(
                f"[Nebula] Nebula load is disabled. Loaded checkpoint from {path}.")
            return partition

        partititon_name = os.path.basename(path)
        logger.info(
            f"[Nebula] Loading {path} under tag {tag} from {self.nebula_load_path}...")

        checkpoint = None
        if tag is None:
            checkpoint = torch_nebula.get_latest_checkpoint(
                persist_path=self.nebula_load_path)
            if checkpoint is None or (checkpoint is not None and checkpoint.tag == ''):
                logger.warning(f"Unable to find latest valid checkpoint from Nebula!")
                return None
        else:
            checkpoint = torch_nebula.get_checkpoint(tag=tag,
                                                     persist_path=self.nebula_load_path)
        partition = checkpoint.load(partititon_name, map_location=map_location)
        logger.info(
            f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.")
        return partition

    def commit(self, tag):
        # nebula commit is called when all files under the given tag are ready
        # to be persisted in an asynchronous way.
        logger.info(
            f"[Nebula] All files for {tag} are saved in tier1. It is ready to start persisting."
        )
        commit_rls = self.checkpoint.commit()
        if not commit_rls:
            logger.error(
                f"[Nebula] Failed to commit the checkpoint, please check the log.")
            return False
        return commit_rls
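For reference, `_get_tag_from_path` derives the nebula tag from the parent directory of each checkpoint file, so the `<save_dir>/<tag>/<file>` layout deepspeed already uses maps directly onto nebula tags. A tiny illustration with a made-up path:

```python
import os

def _get_tag_from_path(path):
    return os.path.basename(os.path.dirname(path))

# Hypothetical checkpoint layout: <save_dir>/<tag>/<partition file>
path = "/checkpoints/global_step100/mp_rank_00_model_states.pt"
print(_get_tag_from_path(path))  # -> "global_step100" (the nebula tag)
print(os.path.basename(path))    # -> the partition name passed to checkpoint.save/load
```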