Merged

44 commits
4b87f30  Refactor universal checkpointing and tensor fragments (tjruwase, Aug 22, 2022)
4317b84  Merge branch 'master' into olruwase/refactor_universal_checkpoint (tjruwase, Aug 23, 2022)
dfc816d  Formatting (tjruwase, Aug 23, 2022)
21aa55a  Merge branch 'master' into olruwase/refactor_universal_checkpoint (tjruwase, Aug 23, 2022)
c683891  Merge branch 'master' into olruwase/refactor_universal_checkpoint (tjruwase, Aug 24, 2022)
89df0b3  Merge branch 'master' into olruwase/refactor_universal_checkpoint (tjruwase, Aug 24, 2022)
622e7ab  Merge branch 'master' into olruwase/refactor_universal_checkpoint (tjruwase, Aug 25, 2022)
115fe42  Merge branch 'master' into olruwase/refactor_universal_checkpoint (tjruwase, Aug 25, 2022)
3c09ada  Merge branch 'master' into olruwase/refactor_universal_checkpoint (tjruwase, Aug 29, 2022)
d40b923  Merge branch 'master' into olruwase/refactor_universal_checkpoint (tjruwase, Aug 29, 2022)
7cf0235  Support zero stage1; Expand TP dim (tjruwase, Sep 1, 2022)
ece4ce3  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 1, 2022)
cae2172  Remove debug prints (tjruwase, Sep 1, 2022)
48b62c2  Merge branch 'olruwase/zero_1_2_universal_ckpt' of github.com:microso… (tjruwase, Sep 1, 2022)
1ace025  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 1, 2022)
529f2d8  Detect sharded optimizer state (tjruwase, Sep 2, 2022)
93246ec  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 3, 2022)
4532059  Merge master (tjruwase, Sep 3, 2022)
024baa8  Format fixes (tjruwase, Sep 3, 2022)
a2f592d  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 5, 2022)
697287d  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 7, 2022)
a5e9900  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 14, 2022)
bfefdec  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 14, 2022)
b78c5f6  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 19, 2022)
e346529  Encode reshaping guide (tjruwase, Sep 19, 2022)
1447fc2  Merge branch 'olruwase/zero_1_2_universal_ckpt' of github.com:microso… (tjruwase, Sep 19, 2022)
1bebe2e  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 19, 2022)
a5afb81  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 22, 2022)
bb85c69  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 26, 2022)
c929f89  More symbolic constants (tjruwase, Sep 26, 2022)
6ba5ad4  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 26, 2022)
83ecf19  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 26, 2022)
0f1738d  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Sep 28, 2022)
6c6823f  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 3, 2022)
16d26b9  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 4, 2022)
a26458a  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 7, 2022)
48d291a  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 9, 2022)
e358ee6  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (mrwyattii, Oct 11, 2022)
4f51d0a  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 13, 2022)
f30ad0f  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 13, 2022)
88bf045  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 14, 2022)
329cb82  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 16, 2022)
0d9dc10  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 17, 2022)
37485ae  Merge branch 'master' into olruwase/zero_1_2_universal_ckpt (tjruwase, Oct 18, 2022)
2 changes: 2 additions & 0 deletions deepspeed/checkpoint/__init__.py
@@ -13,3 +13,5 @@
 from .zero_checkpoint import ZeROCheckpoint

 from .universal_checkpoint import enable_universal_checkpoint
+
+from .constants import *
31 changes: 26 additions & 5 deletions deepspeed/checkpoint/constants.py
@@ -15,7 +15,6 @@
 PARTITION_COUNT = 'partition_count'
 ZERO_STAGE = 'zero_stage'
 CLIP_GRAD = 'clip_grad'
-PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
 FP32_WEIGHT_KEY = "fp32"

 #########################################
@@ -24,20 +23,42 @@
 PARAM = 'param'
 PARAM_SHAPES = 'param_shapes'
 BUFFER_NAMES = 'buffer_names'
-VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
-CAT_DIM = "cat_dim"

 #########################################
 # Checkpoint naming constants
 #########################################
 MODEL_FILE_PREFIX = 'mp_rank_'
-ZERO_FILE_PREFIX = 'bf16_' + 'zero_pp_rank_'
+ZERO_FILE_PREFIX = 'zero_pp_rank_'
 OPTIM_FILE_SUFFIX = '_optim_states.pt'
 MODEL_FILE_SUFFIX = '_model_states.pt'
 LAYER_FILE_PREFIX = 'layer_'
-BF16_ZERO_FILE_PREFIX = ZERO_FILE_PREFIX
+BF16_ZERO_FILE_PREFIX = 'bf16_' + ZERO_FILE_PREFIX
+FP16_ZERO_FILE_PREFIX = 'fp16_' + ZERO_FILE_PREFIX

 #########################################
 # Checkpoint utility keys
 #########################################
 DS_VERSION = 'ds_version'
+
+#########################################
+# Universal Checkpoint keys
+#########################################
+UNIVERSAL_CHECKPOINT_INFO = 'universal_checkpoint_info'
+UNIVERSAL_CHECKPOINT_VERSION_KEY = 'universal_checkpoint_version'
+# Reserve version 0.1 for the hardcoded logic used in BLOOM-176B training
+UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2
+
+# Vocabulary padding
+VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
+PADDED_VOCAB_SIZE = 'padded_vocab_size'
+ORIGINAL_VOCAB_SIZE = 'original_vocab_size'
+
+# Parameter splitting/merging
+PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
+CAT_DIM = "cat_dim"
+
+# Regex list of parameters that require special handling
+VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
+PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
+PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
+PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
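
For context, these keys are meant to travel inside the model checkpoint under UNIVERSAL_CHECKPOINT_INFO. A minimal sketch of how a client training script might populate that block — the regex pattern strings below are illustrative assumptions, not values defined in this PR:

# Hypothetical sketch: populating the new UNIVERSAL_CHECKPOINT_INFO block.
# The pattern strings are illustrative assumptions, not part of this PR.
from deepspeed.checkpoint import (UNIVERSAL_CHECKPOINT_INFO,
                                  UNIVERSAL_CHECKPOINT_VERSION_KEY,
                                  UNIVERSAL_CHECKPOINT_VERSION_VALUE,
                                  VOCABULARY_PARAMETER_PATTERNS,
                                  PARAMETER_TO_AVERAGE_PATTERNS,
                                  PARAMETER_WITH_ROW_PARALLELISM_PATTERNS)

universal_checkpoint_info = {
    UNIVERSAL_CHECKPOINT_VERSION_KEY: UNIVERSAL_CHECKPOINT_VERSION_VALUE,
    # Regex lists naming parameters that need special merge handling.
    VOCABULARY_PARAMETER_PATTERNS: [r'word_embeddings.weight'],
    PARAMETER_TO_AVERAGE_PATTERNS: [r'input_layernorm.(weight|bias)'],
    PARAMETER_WITH_ROW_PARALLELISM_PATTERNS: [r'dense_4h_to_h.weight'],
}
state_dict = {UNIVERSAL_CHECKPOINT_INFO: universal_checkpoint_info}

The top-level import works only because __init__.py now re-exports the constants via the `from .constants import *` added above.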
55 changes: 27 additions & 28 deletions deepspeed/checkpoint/deepspeed_checkpoint.py
@@ -9,7 +9,7 @@
                             get_files,
                             get_files_with_prefix)

-from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
+from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)

 from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map
 from .zero_checkpoint import ZeROCheckpoint
@@ -39,37 +39,36 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
         self.dir = dir
         self._validate_folder(dir)

+        self.zero_checkpoint = ZeROCheckpoint(dir)
+
         self.file_list = get_files(dir)
-        self.zero_files = get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
         self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
         self.mp_rank_files = get_files_with_prefix(self.file_list, MODEL_FILE_PREFIX)

         self.layer_keys = self._get_layer_keys()
         self.layer_count = len(self.layer_keys)
-        self.original_tp_degree = len(
-            get_files_with_prefix(self.layer_files,
-                                  f'{LAYER_FILE_PREFIX}01'))
-        self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
-        self.original_dp_degree = max(
-            1,
-            len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree))

-        self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
-        self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
-        self.dp_degree = self.original_dp_degree if dp_degree is None else dp_degree
+        self.tp_degree = self.zero_checkpoint.get_src_tp_degree(
+        ) if tp_degree is None else tp_degree
+        self.pp_degree = self.zero_checkpoint.get_src_pp_degree(
+        ) if pp_degree is None else pp_degree
+        self.dp_degree = self.zero_checkpoint.get_src_dp_degree(
+        ) if dp_degree is None else dp_degree

-        self.original_world_size = self.original_tp_degree * self.original_pp_degree * self.original_dp_degree
+        self.original_world_size = self.zero_checkpoint.get_src_tp_degree(
+        ) * self.zero_checkpoint.get_src_pp_degree(
+        ) * self.zero_checkpoint.get_src_dp_degree()
         self.world_size = self.tp_degree * self.pp_degree * self.dp_degree

-        self.old_2d_map = meg_2d_parallel_map(self.original_pp_degree,
-                                              self.original_tp_degree)
+        self.old_2d_map = meg_2d_parallel_map(self.zero_checkpoint.get_src_pp_degree(),
+                                              self.zero_checkpoint.get_src_tp_degree())
         self.old_2d_map.simple_init()
-        self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.original_pp_degree,
-                                                  old_tp_degree=self.original_tp_degree,
-                                                  new_pp_degree=self.pp_degree,
-                                                  new_tp_degree=self.tp_degree)
+        self.new_2d_map = reshape_meg_2d_parallel(
+            old_pp_degree=self.zero_checkpoint.get_src_pp_degree(),
+            old_tp_degree=self.zero_checkpoint.get_src_tp_degree(),
+            new_pp_degree=self.pp_degree,
+            new_tp_degree=self.tp_degree)

-        self.zero_checkpoint = ZeROCheckpoint(dir)
         if self.is_change_pp_degree() or self.is_change_tp_degree(
         ) or self.is_change_dp_degree():
             self.zero_checkpoint.reshape(
@@ -88,13 +87,13 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
         self._build_global_state()

     def is_change_tp_degree(self):
-        return self.tp_degree != self.original_tp_degree
+        return self.tp_degree != self.zero_checkpoint.get_src_tp_degree()

     def is_change_pp_degree(self):
-        return self.pp_degree != self.original_pp_degree
+        return self.pp_degree != self.zero_checkpoint.get_src_pp_degree()

     def is_change_dp_degree(self):
-        return self.dp_degree != self.original_dp_degree
+        return self.dp_degree != self.zero_checkpoint.get_src_dp_degree()

     def show_2d_mapping(self):
         print(f'reshaped 2d map ---- begin')
@@ -171,8 +170,8 @@ def _get_checkpoint_value(self, key):
     def get_args(self):
         return self._get_checkpoint_value(ARGS_KEY)

-    def get_checkpoint_info(self):
-        return self._get_checkpoint_value(CHECKPOINT_INFO_KEY)
+    def get_checkpoint_info(self, info_key=CHECKPOINT_INFO_KEY):
+        return self._get_checkpoint_value(info_key)

     def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
         assert tp_index < self.tp_degree
@@ -272,8 +271,8 @@ def _build_transformer_file_map(self):

     def _sanity_check(self):
         assert len(self.mp_rank_files) % self.tp_degree == 0
-        assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
         assert len(self.layer_keys) > 2
+        assert self.zero_checkpoint.num_files % (self.pp_degree * self.tp_degree) == 0
         # XXX: fix me - isn't always the case
         # only true with --pp-partition-method 'type:transformer|embedding' \
         # assert (len(self.layer_keys) - 2) % self.pp_degree == 0
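
With the source topology now owned by ZeROCheckpoint, callers read degrees through it rather than the removed original_*_degree fields. A minimal usage sketch, assuming a placeholder checkpoint path:

# Hypothetical usage of the refactored DeepSpeedCheckpoint.
from deepspeed.checkpoint.deepspeed_checkpoint import DeepSpeedCheckpoint

ds_checkpoint = DeepSpeedCheckpoint('/path/to/global_step1000',  # placeholder
                                    tp_degree=2,   # target TP degree
                                    pp_degree=4)   # target PP degree

# Source degrees come from the embedded ZeROCheckpoint's 3D descriptor.
src_tp = ds_checkpoint.zero_checkpoint.get_src_tp_degree()
print(f'source tp={src_tp}, reshaping tp: {ds_checkpoint.is_change_tp_degree()}')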
25 changes: 19 additions & 6 deletions deepspeed/checkpoint/reshape_3d_utils.py
@@ -1,6 +1,9 @@
-from .reshape_utils import (get_files, get_files_with_prefix, partition_data)
+from .reshape_utils import (get_files,
+                            get_files_with_prefix,
+                            partition_data,
+                            get_zero_files)

-from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
+from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)

 from .reshape_meg_2d import (reshape_meg_2d_parallel, meg_2d_parallel_map)

@@ -34,6 +37,9 @@ def reshape(self, target_3d_desc, verbose=False):
     def get_desc(self):
         return f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})'

+    def world_size(self):
+        return self.pp_degree * self.tp_degree * self.dp_degree
+
     def is_valid(self, pp_index, tp_index, dp_index):
         err_msg = []
         valid = True
@@ -70,10 +76,17 @@ def can_reshape(self, target_3d_desc):

 def get_model_3d_descriptor(dir):
     file_list = get_files(dir)
-    tp_degree = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01'))
-    pp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) // tp_degree
-    num_zero_files = len(get_files_with_prefix(file_list, ZERO_FILE_PREFIX))
-    dp_degree = max(1, num_zero_files // (pp_degree * tp_degree))
+    zero_file_list = get_zero_files(dir)
+    num_pp0_files = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01'))
+    if num_pp0_files > 0:
+        tp_degree = num_pp0_files
+        pp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) // tp_degree
+        dp_degree = max(1, len(zero_file_list) // (pp_degree * tp_degree))
+    else:
+        tp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX))
+        dp_degree = max(1, len(zero_file_list) // tp_degree)
+        pp_degree = 0

     return model_3d_desc(pp_degree, tp_degree, dp_degree)
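
The new pp_degree == 0 branch handles checkpoints with no layer_ files, i.e. runs without pipeline parallelism, where the mp_rank_ file count alone gives the TP degree. A quick sketch of querying the descriptor (path is a placeholder):

# Hypothetical usage: inspect the source 3D topology of a checkpoint folder.
from deepspeed.checkpoint.reshape_3d_utils import get_model_3d_descriptor

src_3d = get_model_3d_descriptor('/path/to/checkpoint')  # placeholder
print(src_3d.get_desc())    # formatted (pp, tp, dp) description
print(src_3d.world_size())  # pp_degree * tp_degree * dp_degree, per the new method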
11 changes: 11 additions & 0 deletions deepspeed/checkpoint/reshape_utils.py
@@ -1,6 +1,7 @@
 import os
 import torch
 from collections import OrderedDict
+from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX)


 def basic_folder_validation(dir):
@@ -32,6 +33,16 @@ def get_files(dir):
     return file_list


+def get_zero_files(dir):
+    file_list = get_files(dir)
+    for prefix in [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]:
+        zero_files = get_files_with_prefix(file_list, prefix)
+        if len(zero_files) > 0:
+            return zero_files
+
+    return []
+
+
 def partition_data(data_list, num_partitions):
     num_elems = len(data_list)
     assert num_elems % num_partitions == 0
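
get_zero_files centralizes the prefix probing that ZeROCheckpoint previously did privately and extends it to the new fp16_ prefix; the first prefix that matches any files wins. A usage sketch (path is a placeholder):

# Hypothetical usage of get_zero_files: tries 'zero_pp_rank_*' first,
# then 'fp16_zero_pp_rank_*', then 'bf16_zero_pp_rank_*'.
from deepspeed.checkpoint.reshape_utils import get_zero_files

zero_files = get_zero_files('/path/to/checkpoint')  # placeholder
if not zero_files:
    raise FileNotFoundError('no ZeRO partition files found')
print(f'{len(zero_files)} ZeRO files, first: {zero_files[0]}')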
8 changes: 3 additions & 5 deletions deepspeed/checkpoint/universal_checkpoint.py
@@ -4,7 +4,6 @@
 import os
 import torch
 import types
-
 from .constants import (FP32_WEIGHT_KEY,
                         PARAM,
                         VOCAB_DIVISIBILITY_PADDING_TENSOR,
@@ -54,18 +53,17 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
         padded_target_vocab_size = self.shape[0] * tp_world_size
         if padded_target_vocab_size > full_hp_param.shape[0]:
             # Need to expand
-            padding_tensor = vocab_divisibility_padding_tensor.expand(
-                padded_target_vocab_size - full_hp_param.shape[0])
+            padding_size = padded_target_vocab_size - full_hp_param.shape[0]
             # Implement the following concat in efficient way using pad
             #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
             full_hp_param = torch.nn.functional.pad(full_hp_param,
                                                     (0,
                                                      0,
                                                      0,
-                                                     padding_tensor.shape[0]),
+                                                     padding_size),
                                                     "constant",
                                                     0)
-            full_hp_param[:-padding_tensor.shape[0], :] = padding_tensor
+            full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
         else:
             # Need to shrink or keep the same
             full_hp_param = full_hp_param[:padded_target_vocab_size, :]
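
The rewrite drops the intermediate expanded padding tensor: pad the merged fp32 parameter with zeros once, then write the divisibility-padding values into the new rows. A standalone sketch of the equivalence with made-up shapes (note the appended rows are the trailing ones, so the slice direction matters):

# Standalone sketch of the pad-instead-of-cat idea; shapes are made up.
import torch

full_hp_param = torch.randn(10, 4)                   # merged fp32 slice
vocab_divisibility_padding_tensor = torch.zeros(4)   # padding row values
padded_target_vocab_size = 12                        # shape[0] * tp_world_size

padding_size = padded_target_vocab_size - full_hp_param.shape[0]
# (0, 0, 0, padding_size) pads nothing on the last dim and appends
# padding_size zero rows on the first dim.
padded = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size),
                                 "constant", 0)
padded[-padding_size:, :] = vocab_divisibility_padding_tensor  # trailing rows

# The concat formulation the inline comment refers to:
reference = torch.cat(
    (full_hp_param,
     vocab_divisibility_padding_tensor.expand(padding_size, -1)),
    dim=0)
assert torch.equal(padded, reference)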
30 changes: 15 additions & 15 deletions deepspeed/checkpoint/zero_checkpoint.py
@@ -3,14 +3,9 @@
 from .constants import (BASE_OPTIMIZER_STATE,
                         GROUP_PADDINGS,
                         OPTIMIZER_STATE_DICT,
-                        PARTITION_COUNT,
-                        ZERO_FILE_PREFIX,
-                        BF16_ZERO_FILE_PREFIX)
+                        PARTITION_COUNT)

-from .reshape_utils import (basic_folder_validation,
-                            get_files,
-                            get_files_with_prefix,
-                            merge_state)
+from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)

 from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)

@@ -21,7 +16,7 @@ class ZeROCheckpoint(object):
     def __init__(self, dir):
         basic_folder_validation(dir)
         self.dir = dir
-        self.file_list = self._get_zero_files(dir)
+        self.file_list = get_zero_files(dir)
         self.num_files = len(self.file_list)
         assert self.num_files > 0, f'No ZeRO files found in {dir}'

@@ -31,6 +26,18 @@ def __init__(self, dir):
                                        dp_degree=self.src_3d.dp_degree)
         self._3d_file_map = self.src_3d.reshape(self.target_3d)

+    def get_src_world_size(self):
+        return self.src_3d.world_size()
+
+    def get_src_tp_degree(self):
+        return self.src_3d.tp_degree
+
+    def get_src_pp_degree(self):
+        return self.src_3d.pp_degree
+
+    def get_src_dp_degree(self):
+        return self.src_3d.dp_degree
+
     def get_file_indices_for_rank(self, pp_index, tp_index, dp_index):
         assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}'
         dp_2d_map = self._3d_file_map[dp_index]
@@ -137,10 +144,3 @@ def _update_partition_count(self, sd):
         num_groups = len(partition_counts)
         sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree
                                                      ] * num_groups
-
-    def _get_zero_files(self, dir):
-        file_list = get_files(dir)
-        zero_files = get_files_with_prefix(file_list, ZERO_FILE_PREFIX)
-        if len(zero_files) > 0:
-            return zero_files
-        return get_files_with_prefix(file_list, BF16_ZERO_FILE_PREFIX)
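
The get_src_* accessors expose the source topology that DeepSpeedCheckpoint above now consumes. A brief sketch (path is a placeholder):

# Hypothetical usage of the new source-topology accessors.
from deepspeed.checkpoint.zero_checkpoint import ZeROCheckpoint

zero_ckpt = ZeROCheckpoint('/path/to/checkpoint')  # placeholder
print(zero_ckpt.get_src_pp_degree(),
      zero_ckpt.get_src_tp_degree(),
      zero_ckpt.get_src_dp_degree())
# world_size is the product of the three source degrees.
assert zero_ckpt.get_src_world_size() == (zero_ckpt.get_src_pp_degree() *
                                          zero_ckpt.get_src_tp_degree() *
                                          zero_ckpt.get_src_dp_degree())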