diff --git a/deepspeed/checkpoint/__init__.py b/deepspeed/checkpoint/__init__.py
index 407a9b50a7bb..1a7a3f18b27d 100644
--- a/deepspeed/checkpoint/__init__.py
+++ b/deepspeed/checkpoint/__init__.py
@@ -13,3 +13,5 @@
 from .zero_checkpoint import ZeROCheckpoint
 
 from .universal_checkpoint import enable_universal_checkpoint
+
+from .constants import *
diff --git a/deepspeed/checkpoint/constants.py b/deepspeed/checkpoint/constants.py
index b46502ceae36..7fe8afaec745 100644
--- a/deepspeed/checkpoint/constants.py
+++ b/deepspeed/checkpoint/constants.py
@@ -15,7 +15,6 @@
 PARTITION_COUNT = 'partition_count'
 ZERO_STAGE = 'zero_stage'
 CLIP_GRAD = 'clip_grad'
-PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
 FP32_WEIGHT_KEY = "fp32"
 
 #########################################
@@ -24,20 +23,42 @@
 PARAM = 'param'
 PARAM_SHAPES = 'param_shapes'
 BUFFER_NAMES = 'buffer_names'
-VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
-CAT_DIM = "cat_dim"
 
 #########################################
 # Checkpoint naming constants
 #########################################
 MODEL_FILE_PREFIX = 'mp_rank_'
-ZERO_FILE_PREFIX = 'bf16_' + 'zero_pp_rank_'
+ZERO_FILE_PREFIX = 'zero_pp_rank_'
 OPTIM_FILE_SUFFIX = '_optim_states.pt'
 MODEL_FILE_SUFFIX = '_model_states.pt'
 LAYER_FILE_PREFIX = 'layer_'
-BF16_ZERO_FILE_PREFIX = ZERO_FILE_PREFIX
+BF16_ZERO_FILE_PREFIX = 'bf16_' + ZERO_FILE_PREFIX
+FP16_ZERO_FILE_PREFIX = 'fp16_' + ZERO_FILE_PREFIX
 
 #########################################
 # Checkpoint utility keys
 #########################################
 DS_VERSION = 'ds_version'
+
+#########################################
+# Universal Checkpoint keys
+#########################################
+UNIVERSAL_CHECKPOINT_INFO = 'universal_checkpoint_info'
+UNIVERSAL_CHECKPOINT_VERSION_KEY = 'universal_checkpoint_version'
+# Reserve version 0.1 for the hardcoded logic used in BLOOM-176B training
+UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2
+
+# Vocabulary padding
+VOCAB_DIVISIBILITY_PADDING_TENSOR = 'vocab_divisibility_padding_tensor'
+PADDED_VOCAB_SIZE = 'padded_vocab_size'
+ORIGINAL_VOCAB_SIZE = 'original_vocab_size'
+
+# Parameter splitting/merging
+PARAM_SLICE_MAPPINGS = 'param_slice_mappings'
+CAT_DIM = "cat_dim"
+
+# Regex list of parameters that require special handling
+VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
+PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
+PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
+PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
diff --git a/deepspeed/checkpoint/deepspeed_checkpoint.py b/deepspeed/checkpoint/deepspeed_checkpoint.py
index 4b8d31e832d7..fa8f80af4c11 100644
--- a/deepspeed/checkpoint/deepspeed_checkpoint.py
+++ b/deepspeed/checkpoint/deepspeed_checkpoint.py
@@ -9,7 +9,7 @@
                             get_files,
                             get_files_with_prefix)
 
-from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
+from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
 
 from .reshape_meg_2d import reshape_meg_2d_parallel, meg_2d_parallel_map
 from .zero_checkpoint import ZeROCheckpoint
@@ -39,37 +39,36 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
         self.dir = dir
         self._validate_folder(dir)
 
+        self.zero_checkpoint = ZeROCheckpoint(dir)
+
         self.file_list = get_files(dir)
-        self.zero_files = get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
         self.layer_files = get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
         self.mp_rank_files = get_files_with_prefix(self.file_list, MODEL_FILE_PREFIX)
 
         self.layer_keys = self._get_layer_keys()
         self.layer_count = len(self.layer_keys)
-        self.original_tp_degree = len(
-            get_files_with_prefix(self.layer_files,
-                                  f'{LAYER_FILE_PREFIX}01'))
-        self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
-        self.original_dp_degree = max(
-            1,
-            len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree))
-
-        self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
-        self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
-        self.dp_degree = self.original_dp_degree if dp_degree is None else dp_degree
-
-        self.original_world_size = self.original_tp_degree * self.original_pp_degree * self.original_dp_degree
+
+        self.tp_degree = self.zero_checkpoint.get_src_tp_degree(
+        ) if tp_degree is None else tp_degree
+        self.pp_degree = self.zero_checkpoint.get_src_pp_degree(
+        ) if pp_degree is None else pp_degree
+        self.dp_degree = self.zero_checkpoint.get_src_dp_degree(
+        ) if dp_degree is None else dp_degree
+
+        self.original_world_size = self.zero_checkpoint.get_src_tp_degree(
+        ) * self.zero_checkpoint.get_src_pp_degree(
+        ) * self.zero_checkpoint.get_src_dp_degree()
         self.world_size = self.tp_degree * self.pp_degree * self.dp_degree
 
-        self.old_2d_map = meg_2d_parallel_map(self.original_pp_degree,
-                                              self.original_tp_degree)
+        self.old_2d_map = meg_2d_parallel_map(self.zero_checkpoint.get_src_pp_degree(),
+                                              self.zero_checkpoint.get_src_tp_degree())
         self.old_2d_map.simple_init()
-        self.new_2d_map = reshape_meg_2d_parallel(old_pp_degree=self.original_pp_degree,
-                                                  old_tp_degree=self.original_tp_degree,
-                                                  new_pp_degree=self.pp_degree,
-                                                  new_tp_degree=self.tp_degree)
+        self.new_2d_map = reshape_meg_2d_parallel(
+            old_pp_degree=self.zero_checkpoint.get_src_pp_degree(),
+            old_tp_degree=self.zero_checkpoint.get_src_tp_degree(),
+            new_pp_degree=self.pp_degree,
+            new_tp_degree=self.tp_degree)
 
-        self.zero_checkpoint = ZeROCheckpoint(dir)
         if self.is_change_pp_degree() or self.is_change_tp_degree(
         ) or self.is_change_dp_degree():
             self.zero_checkpoint.reshape(
@@ -88,13 +87,13 @@ def __init__(self, dir, tp_degree=None, pp_degree=None, dp_degree=None):
         self._build_global_state()
 
     def is_change_tp_degree(self):
-        return self.tp_degree != self.original_tp_degree
+        return self.tp_degree != self.zero_checkpoint.get_src_tp_degree()
 
     def is_change_pp_degree(self):
-        return self.pp_degree != self.original_pp_degree
+        return self.pp_degree != self.zero_checkpoint.get_src_pp_degree()
 
     def is_change_dp_degree(self):
-        return self.dp_degree != self.original_dp_degree
+        return self.dp_degree != self.zero_checkpoint.get_src_dp_degree()
 
     def show_2d_mapping(self):
         print(f'reshaped 2d map ---- begin')
@@ -171,8 +170,8 @@ def _get_checkpoint_value(self, key):
     def get_args(self):
         return self._get_checkpoint_value(ARGS_KEY)
 
-    def get_checkpoint_info(self):
-        return self._get_checkpoint_value(CHECKPOINT_INFO_KEY)
+    def get_checkpoint_info(self, info_key=CHECKPOINT_INFO_KEY):
+        return self._get_checkpoint_value(info_key)
 
     def get_2d_parallel_state(self, tp_index: int, pp_index: int) -> dict:
         assert tp_index < self.tp_degree
@@ -272,8 +271,8 @@ def _build_transformer_file_map(self):
     def _sanity_check(self):
         assert len(self.mp_rank_files) % self.tp_degree == 0
-        assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
         assert len(self.layer_keys) > 2
+        assert self.zero_checkpoint.num_files % (self.pp_degree * self.tp_degree) == 0
 
         # XXX: fix me - isn't always the case
         # only true with --pp-partition-method 'type:transformer|embedding' \
         # assert (len(self.layer_keys) - 2) % self.pp_degree == 0
diff --git a/deepspeed/checkpoint/reshape_3d_utils.py b/deepspeed/checkpoint/reshape_3d_utils.py
index b625eb222589..558f4b65972a 100644
--- a/deepspeed/checkpoint/reshape_3d_utils.py
+++ b/deepspeed/checkpoint/reshape_3d_utils.py
@@ -1,6 +1,9 @@
-from .reshape_utils import (get_files, get_files_with_prefix, partition_data)
+from .reshape_utils import (get_files,
+                            get_files_with_prefix,
+                            partition_data,
+                            get_zero_files)
 
-from .constants import (ZERO_FILE_PREFIX, MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
+from .constants import (MODEL_FILE_PREFIX, LAYER_FILE_PREFIX)
 
 from .reshape_meg_2d import (reshape_meg_2d_parallel, meg_2d_parallel_map)
 
@@ -34,6 +37,9 @@ def reshape(self, target_3d_desc, verbose=False):
     def get_desc(self):
         return f'{PP_DIM},{TP_DIM},{DP_DIM} = ({self.pp_degree}, {self.tp_degree}, {self.dp_degree})'
 
+    def world_size(self):
+        return self.pp_degree * self.tp_degree * self.dp_degree
+
     def is_valid(self, pp_index, tp_index, dp_index):
         err_msg = []
         valid = True
@@ -70,10 +76,17 @@ def can_reshape(self, target_3d_desc):
 
 
 def get_model_3d_descriptor(dir):
     file_list = get_files(dir)
-    tp_degree = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01'))
-    pp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) // tp_degree
-    num_zero_files = len(get_files_with_prefix(file_list, ZERO_FILE_PREFIX))
-    dp_degree = max(1, num_zero_files // (pp_degree * tp_degree))
+    zero_file_list = get_zero_files(dir)
+    num_pp0_files = len(get_files_with_prefix(file_list, f'{LAYER_FILE_PREFIX}01'))
+    if num_pp0_files > 0:
+        tp_degree = num_pp0_files
+        pp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX)) // tp_degree
+        dp_degree = max(1, len(zero_file_list) // (pp_degree * tp_degree))
+    else:
+        tp_degree = len(get_files_with_prefix(file_list, MODEL_FILE_PREFIX))
+        dp_degree = max(1, len(zero_file_list) // tp_degree)
+        pp_degree = 0
+
     return model_3d_desc(pp_degree, tp_degree, dp_degree)
diff --git a/deepspeed/checkpoint/reshape_utils.py b/deepspeed/checkpoint/reshape_utils.py
index 5c3a687967be..0d1a7286af6b 100644
--- a/deepspeed/checkpoint/reshape_utils.py
+++ b/deepspeed/checkpoint/reshape_utils.py
@@ -1,6 +1,7 @@
 import os
 import torch
 from collections import OrderedDict
+from .constants import (ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX)
 
 
 def basic_folder_validation(dir):
@@ -32,6 +33,16 @@ def get_files(dir):
     return file_list
 
 
+def get_zero_files(dir):
+    file_list = get_files(dir)
+    for prefix in [ZERO_FILE_PREFIX, FP16_ZERO_FILE_PREFIX, BF16_ZERO_FILE_PREFIX]:
+        zero_files = get_files_with_prefix(file_list, prefix)
+        if len(zero_files) > 0:
+            return zero_files
+
+    return []
+
+
 def partition_data(data_list, num_partitions):
     num_elems = len(data_list)
     assert num_elems % num_partitions == 0
diff --git a/deepspeed/checkpoint/universal_checkpoint.py b/deepspeed/checkpoint/universal_checkpoint.py
index f791dec6afa4..b58de4871031 100644
--- a/deepspeed/checkpoint/universal_checkpoint.py
+++ b/deepspeed/checkpoint/universal_checkpoint.py
@@ -4,7 +4,6 @@
 import os
 import torch
 import types
-
 from .constants import (FP32_WEIGHT_KEY,
                         PARAM,
                         VOCAB_DIVISIBILITY_PADDING_TENSOR,
@@ -54,18 +53,17 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
            padded_target_vocab_size = self.shape[0] * tp_world_size
            if padded_target_vocab_size > full_hp_param.shape[0]:
                # Need to expand
-                padding_tensor = vocab_divisibility_padding_tensor.expand(
-                    padded_target_vocab_size - full_hp_param.shape[0])
+                padding_size = padded_target_vocab_size - full_hp_param.shape[0]
                # Implement the following concat in efficient way using pad
                #full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
                full_hp_param = torch.nn.functional.pad(full_hp_param,
                                                        (0,
                                                         0,
                                                         0,
-                                                        padding_tensor.shape[0]),
+                                                        padding_size),
                                                        "constant",
                                                        0)
-                full_hp_param[:-padding_tensor.shape[0], :] = padding_tensor
+                full_hp_param[:-padding_size, :] = vocab_divisibility_padding_tensor
            else:
                # Need to shrink or keep the same
                full_hp_param = full_hp_param[:padded_target_vocab_size, :]
diff --git a/deepspeed/checkpoint/zero_checkpoint.py b/deepspeed/checkpoint/zero_checkpoint.py
index 01a6ebe9c1d9..57de3fd5be26 100644
--- a/deepspeed/checkpoint/zero_checkpoint.py
+++ b/deepspeed/checkpoint/zero_checkpoint.py
@@ -3,14 +3,9 @@
 from .constants import (BASE_OPTIMIZER_STATE,
                         GROUP_PADDINGS,
                         OPTIMIZER_STATE_DICT,
-                        PARTITION_COUNT,
-                        ZERO_FILE_PREFIX,
-                        BF16_ZERO_FILE_PREFIX)
+                        PARTITION_COUNT)
 
-from .reshape_utils import (basic_folder_validation,
-                            get_files,
-                            get_files_with_prefix,
-                            merge_state)
+from .reshape_utils import (basic_folder_validation, get_zero_files, merge_state)
 
 from .reshape_3d_utils import (model_3d_desc, get_model_3d_descriptor)
 
@@ -21,7 +16,7 @@ class ZeROCheckpoint(object):
     def __init__(self, dir):
         basic_folder_validation(dir)
         self.dir = dir
-        self.file_list = self._get_zero_files(dir)
+        self.file_list = get_zero_files(dir)
         self.num_files = len(self.file_list)
         assert self.num_files > 0, f'No ZeRO files found in {dir}'
 
@@ -31,6 +26,18 @@ def __init__(self, dir):
                                       dp_degree=self.src_3d.dp_degree)
         self._3d_file_map = self.src_3d.reshape(self.target_3d)
 
+    def get_src_world_size(self):
+        return self.src_3d.world_size()
+
+    def get_src_tp_degree(self):
+        return self.src_3d.tp_degree
+
+    def get_src_pp_degree(self):
+        return self.src_3d.pp_degree
+
+    def get_src_dp_degree(self):
+        return self.src_3d.dp_degree
+
     def get_file_indices_for_rank(self, pp_index, tp_index, dp_index):
         assert dp_index < len(self._3d_file_map), f'DP index {dp_index} >= DP degree {len(self._3d_file_map)}'
         dp_2d_map = self._3d_file_map[dp_index]
@@ -137,10 +144,3 @@ def _update_partition_count(self, sd):
         num_groups = len(partition_counts)
         sd[OPTIMIZER_STATE_DICT][PARTITION_COUNT] = [self.target_3d.dp_degree
                                                      ] * num_groups
-
-    def _get_zero_files(self, dir):
-        file_list = get_files(dir)
-        zero_files = get_files_with_prefix(file_list, ZERO_FILE_PREFIX)
-        if len(zero_files) > 0:
-            return zero_files
-        return get_files_with_prefix(file_list, BF16_ZERO_FILE_PREFIX)
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 216dbc35c9fd..49c51fa7c91f 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -456,86 +456,3 @@ def _get_padded_tensor(src_tensor, size):
     slice_tensor = torch.narrow(padded_tensor, 0, 0, src_tensor.numel())
     slice_tensor.data.copy_(src_tensor.data)
     return padded_tensor
-
-
-'''
-Logic for lp_param to hp_param mapping
-
-lp lp0 lp1 lp2 lp3 lp4 <------- indices/names
-lp [ ][ ][ ][ ][ ] <-------- tensors
-flat_lp [ ] <-------- flat lp params
-flat_hp [ ] <------------------ flat hp partition on current rank
-full_hp [ ] <------- full flat hp params
-
-
-lp2
-    full numel = 16
-    lp_frag
-        numel = 12
-        frag_start = 3
-        frag_end = 15
-    hp_frag
-        numel = 12
-        frag_start = 0
-        frag_end = 11
-
-    hp_frag.copy_(lp_frag)
-
-
-lp3:
-    full numel = 4
-    lp_frag
-        numel = 4
-        start = 0
-        end = 3
-    hp_frag
-        numel = 4
-        start = 12
-        end = 15
-
-
-lp4:
-    full numel = 12
-    lp_frag
-        numel = 4
-        start = 0
-        end = 3
-    hp_frag
-        numel = 4
-        start = 16
-        end = 19
-
-
-
-Visual depiction of above
-lp { }
-flat_lp [ ]
-flat_hp ( )
-
-
-flat_lp [ { ( } ) ]
-    lx hx ly hy
-      ly-hx
-
-
-lp { }
-flat_lp [ ]
-flat_hp ( )
-
-
-flat_lp [ ( { ) } ]
-    hx lx hy ly
-      hy-lx
-
-lp { }
-flat_lp [ ]
-flat_hp ( )
-
-
-flat_lp [ ( { } ) ]
-    hx lx ly hy
-      ly-lx
-
-lp -> (lx, hy)
-flat_hp -> (hx, hy)
-'''
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 28a38eeac208..1f4331239f3b 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1400,6 +1400,7 @@ def _configure_zero_optimizer(self, optimizer):
                     overlap_comm = False
             optimizer = DeepSpeedZeroOptimizer(
                 optimizer,
+                self.param_names,
                 timers=timers,
                 static_loss_scale=self.loss_scale(),
                 dynamic_loss_scale=self.dynamic_loss_scale(),
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index c3511e2819f9..c2c079e386e6 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -3,9 +3,11 @@
 '''
 
 import torch
+import os
 from deepspeed import comm as dist
 from torch._six import inf
 from packaging import version as pkg_version
+from collections import OrderedDict
 
 from deepspeed.runtime import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
@@ -31,7 +33,10 @@
                                              SINGLE_PARTITION_OF_FP32_GROUPS,
                                              BASE_OPTIMIZER_STATE,
                                              CLIP_GRAD,
-                                             ZERO_STAGE)
+                                             ZERO_STAGE,
+                                             PARAM_SLICE_MAPPINGS)
+from deepspeed.utils import link_hp_params
+from deepspeed.checkpoint import enable_universal_checkpoint
 
 # Toggle this to true to enable correctness test
 # with gradient partitioning and without
@@ -103,6 +108,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     """
     def __init__(self,
                  init_optimizer,
+                 param_names,
                  timers,
                  static_loss_scale=1.0,
                  dynamic_loss_scale=False,
@@ -140,7 +146,8 @@ def __init__(self,
         # 2. keep common stuff here in case we need to add ne552w fused optimizer later
         self.elastic_checkpoint = elastic_checkpoint
-
+        self.param_names = param_names
+        self.mpu = mpu
 
         # differences from apex.fp16_utils:
         # - assume all model params in fp16
         # - assume all params requires grad
@@ -518,6 +525,42 @@ def __init__(self,
         if dist.get_rank(group=self.dp_process_group) == 0:
             see_memory_usage(f"After initializing ZeRO optimizer", force=True)
 
+        self._link_all_hp_params()
+        self._enable_universal_checkpoint()
+        self._param_slice_mappings = self._create_param_mapping()
+
+    def _enable_universal_checkpoint(self):
+        for lp_param_group in self.bit16_groups:
+            enable_universal_checkpoint(param_list=lp_param_group)
+
+    def _create_param_mapping(self):
+        param_mapping = []
+        for i, _ in enumerate(self.optimizer.param_groups):
+            param_mapping_per_group = OrderedDict()
+            for lp in self.bit16_groups[i]:
+                if lp._hp_mapping is not None:
+                    lp_name = self.param_names[lp]
+                    param_mapping_per_group[
+                        lp_name] = lp._hp_mapping.get_hp_fragment_address()
+            param_mapping.append(param_mapping_per_group)
+
+        return param_mapping
+
+    def _link_all_hp_params(self):
+        dp_world_size = dist.get_world_size(group=self.dp_process_group)
+        for i, _ in enumerate(self.optimizer.param_groups):
+            # Link bit16 and fp32 params in partition
+            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
+            partition_size = self.bit16_groups_flat[i].numel() // dp_world_size
+            flat_hp_partition = self.single_partition_of_fp32_groups[i]
+            link_hp_params(
+                lp_param_list=self.bit16_groups[i],
+                flat_hp_partition=flat_hp_partition,
+                partition_start=partition_id * partition_size,
+                partition_size=partition_size,
+                partition_optimizer_state=self.optimizer.state[flat_hp_partition],
+                dp_group=self.real_dp_process_group[i])
+
     def is_moe_group(self, group):
         return 'moe' in group and group['moe']
 
@@ -1826,6 +1869,21 @@ def step(self, closure=None):
 
         return
 
+    @torch.no_grad()
+    def update_lp_params(self):
+        for i, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)):
+            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
+            bit16_partitions[partition_id].data.copy_(fp32_partition.data)
+            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
+            # if i == 0:
+            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)
+
+        all_gather_dp_groups(
+            partitioned_param_groups=self.parallel_partitioned_bit16_groups,
+            dp_process_group=self.real_dp_process_group,
+            start_alignment_factor=self.nccl_start_alignment_factor,
+            allgather_bucket_size=self.allgather_bucket_size)
+
     def _average_expert_grad_norms(self, norm_groups):
         for i, norm in enumerate(norm_groups):
             if self.is_moe_param_group[i]:
@@ -2058,6 +2116,7 @@ def state_dict(self):
         state_dict[PARTITION_COUNT] = self.partition_count
 
         state_dict[DS_VERSION] = version
+        state_dict[PARAM_SLICE_MAPPINGS] = self._param_slice_mappings
 
         return state_dict
 
@@ -2169,6 +2228,45 @@ def load_state_dict(self,
                         load_optimizer_states=True,
                         load_from_fp32_weights=False,
                         checkpoint_folder=None):
+        if checkpoint_folder:
+            self._load_universal_checkpoint(checkpoint_folder,
+                                            load_optimizer_states,
+                                            load_from_fp32_weights)
+        else:
+            self._load_legacy_checkpoint(state_dict_list,
+                                         load_optimizer_states,
+                                         load_from_fp32_weights)
+
+    def _load_universal_checkpoint(self,
+                                   checkpoint_folder,
+                                   load_optimizer_states,
+                                   load_from_fp32_weights):
+        self._load_hp_checkpoint_state(checkpoint_folder)
+
+    @property
+    def param_groups(self):
+        """Forward the wrapped optimizer's parameters."""
+        return self.optimizer.param_groups
+
+    def _load_hp_checkpoint_state(self, checkpoint_dir):
+        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
+        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
+        tp_world_size = self.mpu.get_slice_parallel_world_size()
+
+        for i, _ in enumerate(self.optimizer.param_groups):
+            for lp in self.bit16_groups[i]:
+                if lp._hp_mapping is not None:
+                    #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
+                    lp.load_hp_checkpoint_state(
+                        os.path.join(checkpoint_dir,
+                                     self.param_names[lp]),
+                        tp_rank,
+                        tp_world_size)
+
+    def _load_legacy_checkpoint(self,
+                                state_dict_list,
+                                load_optimizer_states=True,
+                                load_from_fp32_weights=False):
         r"""Loading ZeRO checkpoint
 
         Arguments:
@@ -2269,6 +2367,9 @@ def load_state_dict(self,
             # option 1 from above
             self._restore_from_bit16_weights()
 
+        if load_optimizer_states:
+            self._link_all_hp_params()
+
 
 def _handle_overflow(cpu_sum, x, i):
     import math
diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py
index 913b188df9a9..37e394b666e5 100644
--- a/deepspeed/utils/tensor_fragment.py
+++ b/deepspeed/utils/tensor_fragment.py
@@ -82,14 +82,13 @@ def get_hp_fragment_mapping(lp_param,
     hp_fragment_tensor = flat_hp_partition.narrow(0,
                                                   hp_frag_address.start,
                                                   hp_frag_address.numel)
-
     optim_fragment = {
         key: value.narrow(0,
                           hp_frag_address.start,
                           hp_frag_address.numel)
        for key,
        value in optimizer_state_dict.items()
-        if torch.is_tensor(value) and value.dim() > 0
+        if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
     }
 
     lp_frag_address = fragment_address(start=fragment_start - lp_start,
@@ -103,3 +102,86 @@ def get_hp_fragment_mapping(lp_param,
                            hp_fragment=hp_fragment_tensor,
                            hp_fragment_address=hp_frag_address,
                            optim_fragment=optim_fragment)
+
+
+'''
+Logic for lp_param to hp_param mapping
+
+lp lp0 lp1 lp2 lp3 lp4 <------- indices/names
+lp [ ][ ][ ][ ][ ] <-------- tensors
+flat_lp [ ] <-------- flat lp params
+flat_hp [ ] <------------------ flat hp partition on current rank
+full_hp [ ] <------- full flat hp params
+
+
+lp2
+    full numel = 16
+    lp_frag
+        numel = 12
+        frag_start = 3
+        frag_end = 15
+    hp_frag
+        numel = 12
+        frag_start = 0
+        frag_end = 11
+
+    hp_frag.copy_(lp_frag)
+
+
+lp3:
+    full numel = 4
+    lp_frag
+        numel = 4
+        start = 0
+        end = 3
+    hp_frag
+        numel = 4
+        start = 12
+        end = 15
+
+
+lp4:
+    full numel = 12
+    lp_frag
+        numel = 4
+        start = 0
+        end = 3
+    hp_frag
+        numel = 4
+        start = 16
+        end = 19
+
+
+
+Visual depiction of above
+lp { }
+flat_lp [ ]
+flat_hp ( )
+
+
+flat_lp [ { ( } ) ]
+    lx hx ly hy
+      ly-hx
+
+
+lp { }
+flat_lp [ ]
+flat_hp ( )
+
+
+flat_lp [ ( { ) } ]
+    hx lx hy ly
+      hy-lx
+
+lp { }
+flat_lp [ ]
+flat_hp ( )
+
+
+flat_lp [ ( { } ) ]
+    hx lx ly hy
+      ly-lx
+
+lp -> (lx, hy)
+flat_hp -> (hx, hy)
+'''
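
Note on the lp_param to hp_param mapping comment that the last two hunks move from bf16_optimizer.py into tensor_fragment.py: the fragment addresses it describes are simply the intersection of one lp tensor's span in the flattened lp space with the span of this rank's flat hp partition. The snippet below is a minimal, self-contained sketch of that interval arithmetic only; it is not DeepSpeed's implementation, and the names (map_lp_to_hp, lx/ly/hx/hy, FragmentAddress) are taken from the diagram purely for illustration.

    from collections import namedtuple

    FragmentAddress = namedtuple('FragmentAddress', ['start', 'numel'])


    def map_lp_to_hp(lx, ly, hx, hy):
        # One lp tensor occupies [lx, ly) in the flattened lp space; this rank's
        # flat hp partition covers [hx, hy). The rank owns only the overlap.
        frag_start = max(lx, hx)
        frag_end = min(ly, hy)
        if frag_start >= frag_end:
            return None  # this lp tensor contributes nothing to this partition

        numel = frag_end - frag_start
        lp_frag = FragmentAddress(start=frag_start - lx, numel=numel)  # offset inside the lp tensor
        hp_frag = FragmentAddress(start=frag_start - hx, numel=numel)  # offset inside flat_hp
        return lp_frag, hp_frag


    # lp straddles the left edge of the partition (the "ly-hx" case in the diagram):
    # only its tail is owned here, and it lands at the start of flat_hp.
    print(map_lp_to_hp(lx=0, ly=16, hx=4, hy=24))
    # lp lies entirely inside the partition (the "ly-lx" case).
    print(map_lp_to_hp(lx=16, ly=20, hx=4, hy=24))

The hypothetical map_lp_to_hp above mirrors what get_hp_fragment_mapping computes per lp tensor before narrowing flat_hp_partition and the matching optimizer-state tensors to the fragment.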