From 3afaf2df34b88764a008a057c0d96c37c52b00dd Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Mon, 6 Sep 2021 14:54:06 +0000 Subject: [PATCH 01/10] Checkpoint conversion tools --- tools/convert_checkpoint/README.md | 26 ++++ .../deepspeed_checkpoint.py | 133 ++++++++++++++++++ .../deepspeed_to_megatron.py | 102 ++++++++++++++ .../convert_checkpoint/inspect_checkpoint.py | 39 +++++ .../inspect_deepspeed_checkpoint.py | 76 ++++++++++ 5 files changed, 376 insertions(+) create mode 100644 tools/convert_checkpoint/README.md create mode 100644 tools/convert_checkpoint/deepspeed_checkpoint.py create mode 100644 tools/convert_checkpoint/deepspeed_to_megatron.py create mode 100644 tools/convert_checkpoint/inspect_checkpoint.py create mode 100644 tools/convert_checkpoint/inspect_deepspeed_checkpoint.py diff --git a/tools/convert_checkpoint/README.md b/tools/convert_checkpoint/README.md new file mode 100644 index 00000000000..a5edf5dacbb --- /dev/null +++ b/tools/convert_checkpoint/README.md @@ -0,0 +1,26 @@ +# Introduction +This folder is a collection of scripts for converting checkpoints of one training framework (e.g., DeepSpeed) into that of a different framework (e.g., Megatron-LM). inspecting checkpoints. The folder also contains scripts for inspecting checkpoint files and folders, which could be useful when developing checkpoint conversion logic. At the time of creation, this folder contains scripts to convert DeepSpeed checkpoints to Megatron-LM checkpoints (this motivated this effort as part of the BigScience project). + +Here are the list and details of checkpoint conversions provided by the available scripts. +1. [DeepSpeed to Megatron-LM](#DeepSpeed to Megatron-LM) + + +## DeepSpeed to Megatron-LM +The (current implementation of the) converter extracts model paramemters from a DeepSpeed checkpoint (i.e., excludes other training states such as args, optimizer, scheduler, etc) and convert into a Megatron-LM checkpoint similarly containing only model parameters. The resulting Megatron-LM checkpoint could be loaded into Megatron-LM framework for finetuning or inference. Tensor parallelism (TP) and pipeline parallelism (PP) are supported in the sense that the generated Megatron-LM checkpoint (folders and files) will be of the same TP and PP of the training that created the input DeepSpeed checkpoint. The entry point of the converter is `deepspeed_to_megatron.py`, which as the following usage: +```bash +python tools/convert_checkpoint/deepspeed_to_megatron.py -h +Convert DeepSpeed Checkpoint to Megatron Checkpoint +usage: deepspeed_to_megatron.py [-h] [--input_folder INPUT_FOLDER] + [--output_folder OUTPUT_FOLDER] + +optional arguments: + -h, --help show this help message and exit + --input_folder INPUT_FOLDER + Input DeepSpeed Checkpoint folder + --output_folder OUTPUT_FOLDER + Output Megatron checkpoint folder +``` + +The following scripts which proved useful for debugging are also included: +1. `inspect_deepspeed_checkpoint.py` to view the contents of a DeepSpeed checkpoint folder. +2. `inspect_checkpoint.py` to view the contents of a PyTorch checkpoint file. 
\ No newline at end of file diff --git a/tools/convert_checkpoint/deepspeed_checkpoint.py b/tools/convert_checkpoint/deepspeed_checkpoint.py new file mode 100644 index 00000000000..93878f609ff --- /dev/null +++ b/tools/convert_checkpoint/deepspeed_checkpoint.py @@ -0,0 +1,133 @@ +import os +from typing import Dict +import torch + +ZERO_FILE_PREFIX = 'zero_pp_rank_' +LAYER_FILE_PREFIX = 'layer_' +MP_RANK_FILE_PREFIX = 'mp_rank_' +EMBEDDING_LAYER_INDEX = 0 +FINAL_LAYER_NORM_INDEX = -1 +class DeepSpeedCheckpoint(object): + def __init__(self, dir): + self.dir = dir + self.file_list = self._get_files(dir) + self.zero_files = self._get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX) + self.layer_files = self._get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX) + self.mp_rank_files = self._get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX) + self.layer_keys = self._get_layer_keys() + self.layer_count = len(self.layer_keys) + self.tp_degree = len(self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01')) + self.pp_degree = len(self.mp_rank_files) // self.tp_degree + self.dp_degree = len(self.zero_files) // (self.pp_degree * self.tp_degree) + self._sanity_check() + self.pp_to_transformer_map = self._build_pp_transformer_map() + self.transformer_file_map = self._build_transformer_file_map() + self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX) + self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX) + + + def show_layer_file_map(self): + self._dump_file_map(self.layer_file_map) + + def show_tp_embedding_map(self): + self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers') + + def show_tp_final_norm_map(self): + self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers') + + def show_pp_tranformer_map(self): + self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers') + + def show_transformer_file_map(self): + self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files') + + def get_embedding_state(self, tp_index: int) -> Dict: + embedding_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[EMBEDDING_LAYER_INDEX]) + assert tp_index < len(embedding_files) + sd = torch.load(embedding_files[tp_index]) + return sd + + def get_transformer_state(self, tp_index: int, pp_index: int) -> list: + assert tp_index < self.tp_degree + assert pp_index < self.pp_degree + t_list = [] + for fname in self.transformer_file_map[(tp_index, pp_index)]: + sd = torch.load(fname) + t_list.append(sd) + return t_list + + def get_final_norm_state(self, tp_index:int) -> Dict: + final_norm_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[FINAL_LAYER_NORM_INDEX]) + assert tp_index < len(final_norm_files) + sd = torch.load(final_norm_files[tp_index]) + return sd + + def _build_tp_other_layer_map(self, layer_index:int): + assert layer_index < len(self.layer_files) + layer_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[layer_index]) + data_map = {i:fname for i, fname in enumerate(layer_files)} + return data_map + + def _build_pp_transformer_map(self): + data_map = {} + transformer_layers = self.layer_keys[1:-1] + layers_per_pp = len(transformer_layers) // self.pp_degree + data_map = {i:transformer_layers[i*layers_per_pp:(i+1)*layers_per_pp] for i in range(0, self.pp_degree)} + return data_map + + def _dump_mapping(self, data_map, map_tag = None): + if map_tag is not None: + print(f'Dump mapping: {map_tag}') + for k, v in 
data_map.items(): + print(f'{k} = {v}') + + def _build_transformer_file_map(self): + transformer_layer_keys = self.layer_keys[1:-1] + file_map = {} + layers_per_pp = len(transformer_layer_keys) // self.pp_degree + for key_index, layer_key in enumerate(transformer_layer_keys): + pp_index = key_index // layers_per_pp + layer_files = self._get_files_with_prefix(self.layer_files, layer_key) + assert len(layer_files) == self.tp_degree + for file_index, fname in enumerate(layer_files): + map_key = (file_index, pp_index) + if not map_key in file_map.keys(): + file_map[map_key] = [] + file_map[map_key].append(fname) + + return file_map + + def _sanity_check(self): + assert len(self.mp_rank_files) % self.tp_degree == 0 + assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0 + assert len(self.layer_keys) > 2 + assert (len(self.layer_keys) - 2) % self.pp_degree == 0 + + def _get_files_with_prefix(self, all_files, prefix): + file_list = [] + for file_path in all_files: + _, fname = os.path.split(file_path) + if fname.startswith(prefix): + file_list.append(file_path) + + return sorted(file_list) + + def validate_files(self): + for file in self.file_list: + if not os.path.isfile(file): + print(f'Error: {file} is not existent') + + def _get_files(self, dir): + file_list = [] + for root, dirs, files in os.walk(dir): + for file in files: + file_list.append(os.path.join(root, file)) + return file_list + + def _get_layer_keys(self): + key_set = set() + key_len = len(LAYER_FILE_PREFIX) + 2 + for file_path in self.layer_files: + _, fname = os.path.split(file_path) + key_set.add(fname[:key_len]) + return sorted(list(key_set)) \ No newline at end of file diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py new file mode 100644 index 00000000000..3dcec5945be --- /dev/null +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -0,0 +1,102 @@ +import argparse +import os +import torch +from collections import OrderedDict +from deepspeed_checkpoint import DeepSpeedCheckpoint + +MODEL_KEY = 'model' +LANGUGAGE_MODEL_KEY = 'language_model' +EMBEDDING_KEY = 'embedding' +ENCODER_KEY = 'encoder' +WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head' +WORD_EMBEDDINGS_KEY = 'word_embeddings' +FINAL_LAYER_NORM_KEY ='final_layernorm' + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--input_folder', default=None, type=str, help='Input DeepSpeed Checkpoint folder') + parser.add_argument('--output_folder', default=None, type=str, help='Output Megatron checkpoint folder') + args = parser.parse_args() + print(f'args = {args}') + return args + + +def _convert_ds_transformer_state(sd_list): + new_sd = OrderedDict() + for i, sd in enumerate(sd_list): + for key, value in sd.items(): + new_key = f'layers.{i}.{key}' + new_sd[new_key] = value + + return new_sd + +def _create_checkpoint_path(tp_index, pp_index): + rank_folder = f'mp_rank_{tp_index:02d}_{pp_index:03d}' + return os.path.join(rank_folder, 'model_optim_rng.pt') + +def _create_megatron_dict(): + language_model_dict = { + EMBEDDING_KEY: {}, + ENCODER_KEY: {} + } + megatron_dict = { + MODEL_KEY: {LANGUGAGE_MODEL_KEY: language_model_dict}, + WORD_EMBEDDINGS_FOR_HEAD_KEY: OrderedDict() + } + return megatron_dict + + +def _save_checkpoint(file_path, chkpt_sd): + dir, _ = os.path.split(file_path) + os.makedirs(dir, exist_ok=True) + torch.save(chkpt_sd, file_path) + +def _create_rank_checkpoint(ds_checkpoint, output_folder, tp_index, pp_index): + checkpoint_path = 
os.path.join(output_folder, _create_checkpoint_path(tp_index, pp_index)) + + meg_encoder_sd = OrderedDict() + meg_embedding_sd = OrderedDict() + meg_embedding_for_head_sd = OrderedDict() + + transformer_sd = ds_checkpoint.get_transformer_state(tp_index, pp_index) + meg_encoder_sd.update(_convert_ds_transformer_state(transformer_sd)) + + if pp_index in [0, ds_checkpoint.pp_degree - 1]: + embedding_sd = ds_checkpoint.get_embedding_state(tp_index) + if pp_index == 0: + meg_embedding_sd.update(embedding_sd) + + if pp_index == ds_checkpoint.pp_degree -1: + for key, value in embedding_sd.items(): + if key.startswith(WORD_EMBEDDINGS_KEY): + fields = key.split('.') + new_fields = fields[1:] + new_key = '.'.join(new_fields) + meg_embedding_for_head_sd[new_key] = value + + final_norm_sd = ds_checkpoint.get_final_norm_state(tp_index) + new_final_norm_sd = {f'{FINAL_LAYER_NORM_KEY}.{key}': value for key, value in final_norm_sd.items()} + meg_encoder_sd.update(new_final_norm_sd) + + checkpoint_sd = _create_megatron_dict() + checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd + checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd + checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd + + _save_checkpoint(checkpoint_path, checkpoint_sd) + + +def main(): + print(f'Convert DeepSpeed Checkpoint to Megatron Checkpoint') + + args = parse_arguments() + print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}') + + ds_checkpoint = DeepSpeedCheckpoint(args.input_folder) + for i in range(0, ds_checkpoint.tp_degree): + for j in range(0, ds_checkpoint.pp_degree): + _create_rank_checkpoint(ds_checkpoint, args.output_folder, i, j) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tools/convert_checkpoint/inspect_checkpoint.py b/tools/convert_checkpoint/inspect_checkpoint.py new file mode 100644 index 00000000000..86d05c977fd --- /dev/null +++ b/tools/convert_checkpoint/inspect_checkpoint.py @@ -0,0 +1,39 @@ +import torch +import sys +import os +from collections import OrderedDict + + +def dump_data(datum, name_list=[]): + if type(datum) in (dict, OrderedDict): + for k, v in datum.items(): + dump_data(v, name_list+[str(k)]) + elif type(datum) in (list, tuple): + for v in datum: + dump_data(v, name_list) + elif torch.is_tensor(datum): + prefix = '.'.join(name_list) + print(f'tensor {prefix} = {datum.shape}') + else: + #pass + prefix = '.'.join(name_list) + print(f'other {prefix} = {datum}') + +def main(): + if len(sys.argv) < 2: + print(f'Usage: {sys.argv[0]} ') + exit(1) + + ckpt_file = sys.argv[1] + if not os.path.isfile(ckpt_file): + print(f'{ckpt_file} is not a valid file') + exit(1) + + print(f'loading checkpoint file: {ckpt_file}') + sd = torch.load(ckpt_file) + dump_data(sd) + + quit() + + +main() \ No newline at end of file diff --git a/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py b/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py new file mode 100644 index 00000000000..733a082b4dc --- /dev/null +++ b/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py @@ -0,0 +1,76 @@ +import argparse +from deepspeed_checkpoint import DeepSpeedCheckpoint + +def list_files(file_list, tag): + print(f'Listing files: {tag}') + for i, file in enumerate(file_list): + print(f'{i+1}: {file}') + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--folder', default=None, type=str, help='DeepSpeed Checkpoint folder') 
+ args = parser.parse_args() + print(f'args = {args}') + return args + + +def show_input_files(ds_checkpoint): + list_files(ds_checkpoint.file_list, 'all') + list_files(ds_checkpoint.zero_files, 'zero') + list_files(ds_checkpoint.layer_files, 'layer') + list_files(ds_checkpoint.mp_rank_files, 'mp rank') + +def show_simple_state(ds_checkpoint): + print(f'layer keys = {ds_checkpoint.layer_keys}') + print(f'layer count = {ds_checkpoint.layer_count}') + + print(f'tp_degree_count = {ds_checkpoint.tp_degree}') + print(f'pp_degree_count = {ds_checkpoint.pp_degree}') + print(f'dp_degree_count = {ds_checkpoint.dp_degree}') + +def show_mappings(ds_checkpoint): + ds_checkpoint.show_pp_tranformer_map() + ds_checkpoint.show_transformer_file_map() + ds_checkpoint.show_tp_embedding_map() + ds_checkpoint.show_tp_final_norm_map() + +def show_state_summary(tag, sd): + summary = {k:v.shape for k,v in sd.items()} + print(f'{tag} = {summary}') + +def show_embedding_states(ds_checkpoint): + for i in range(0, ds_checkpoint.tp_degree): + sd = ds_checkpoint.get_embedding_state(i) + show_state_summary(f'embedding[{i}]', sd) + +def show_final_norm_states(ds_checkpoint): + for i in range(0, ds_checkpoint.tp_degree): + sd = ds_checkpoint.get_final_norm_state(i) + show_state_summary(f'final_norm[{i}]', sd) + +def show_transformer_states(ds_checkpoint): + for i in range(0, ds_checkpoint.tp_degree): + for j in range(0, ds_checkpoint.pp_degree): + state_list = ds_checkpoint.get_transformer_state(tp_index=i, pp_index=j) + print(f'tp_pp_rank[{i},{j}] = ') + for k, sd in enumerate(state_list): + show_state_summary(f' block[{k}]', sd) + print("") + + +def main(): + print(f'Inspecting DeepSpeed Checkpoint') + args = parse_arguments() + + ds_checkpoint = DeepSpeedCheckpoint(args.folder) + ds_checkpoint.validate_files() + + show_input_files(ds_checkpoint) + show_simple_state(ds_checkpoint) + show_mappings(ds_checkpoint) + show_embedding_states(ds_checkpoint) + show_final_norm_states(ds_checkpoint) + show_transformer_states(ds_checkpoint) + +if __name__ == "__main__": + main() \ No newline at end of file From 6f6b3420e16ff5486a2a17be4ced6800bcbf6bad Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Mon, 6 Sep 2021 15:12:31 +0000 Subject: [PATCH 02/10] Fix formatting --- tools/convert_checkpoint/deepspeed_checkpoint.py | 3 ++- tools/convert_checkpoint/deepspeed_to_megatron.py | 2 +- tools/convert_checkpoint/inspect_checkpoint.py | 3 ++- tools/convert_checkpoint/inspect_deepspeed_checkpoint.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tools/convert_checkpoint/deepspeed_checkpoint.py b/tools/convert_checkpoint/deepspeed_checkpoint.py index 93878f609ff..a1fd06e6b60 100644 --- a/tools/convert_checkpoint/deepspeed_checkpoint.py +++ b/tools/convert_checkpoint/deepspeed_checkpoint.py @@ -130,4 +130,5 @@ def _get_layer_keys(self): for file_path in self.layer_files: _, fname = os.path.split(file_path) key_set.add(fname[:key_len]) - return sorted(list(key_set)) \ No newline at end of file + return sorted(list(key_set)) + \ No newline at end of file diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py index 3dcec5945be..21b7239857e 100644 --- a/tools/convert_checkpoint/deepspeed_to_megatron.py +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -99,4 +99,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tools/convert_checkpoint/inspect_checkpoint.py 
b/tools/convert_checkpoint/inspect_checkpoint.py index 86d05c977fd..9b092020aad 100644 --- a/tools/convert_checkpoint/inspect_checkpoint.py +++ b/tools/convert_checkpoint/inspect_checkpoint.py @@ -36,4 +36,5 @@ def main(): quit() -main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py b/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py index 733a082b4dc..4172e73c59a 100644 --- a/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py +++ b/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py @@ -73,4 +73,4 @@ def main(): show_transformer_states(ds_checkpoint) if __name__ == "__main__": - main() \ No newline at end of file + main() From 0fa9543d71e52e34fbef3b0689f058731846d124 Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Thu, 9 Sep 2021 01:13:11 +0000 Subject: [PATCH 03/10] 1) Provide args in converted checkpoint 2) Reshape TP and PP degrees --- tools/convert_checkpoint/README.md | 12 ++- .../deepspeed_checkpoint.py | 77 ++++++++++++++----- .../deepspeed_to_megatron.py | 33 +++++--- .../inspect_deepspeed_checkpoint.py | 6 +- 4 files changed, 94 insertions(+), 34 deletions(-) diff --git a/tools/convert_checkpoint/README.md b/tools/convert_checkpoint/README.md index a5edf5dacbb..7c56bf3ca3b 100644 --- a/tools/convert_checkpoint/README.md +++ b/tools/convert_checkpoint/README.md @@ -6,12 +6,14 @@ Here are the list and details of checkpoint conversions provided by the availabl ## DeepSpeed to Megatron-LM -The (current implementation of the) converter extracts model paramemters from a DeepSpeed checkpoint (i.e., excludes other training states such as args, optimizer, scheduler, etc) and convert into a Megatron-LM checkpoint similarly containing only model parameters. The resulting Megatron-LM checkpoint could be loaded into Megatron-LM framework for finetuning or inference. Tensor parallelism (TP) and pipeline parallelism (PP) are supported in the sense that the generated Megatron-LM checkpoint (folders and files) will be of the same TP and PP of the training that created the input DeepSpeed checkpoint. The entry point of the converter is `deepspeed_to_megatron.py`, which as the following usage: +The (current implementation of the) converter extracts args and model paramemters from a DeepSpeed checkpoint (i.e., excludes other training states such as optimizer, scheduler, etc) and convert into a Megatron-LM checkpoint similarly containing only model parameters. The converter also provides a best-effort attempt to reshape the tensor-parallelism and pipeline parallelism degrees for the checkpoint. The resulting Megatron-LM checkpoint could be loaded into Megatron-LM framework for finetuning or inference. Tensor parallelism (TP) and pipeline parallelism (PP) are supported in the sense that the generated Megatron-LM checkpoint (folders and files) will be of the same TP and PP of the training that created the input DeepSpeed checkpoint. 
The entry point of the converter is `deepspeed_to_megatron.py`, which as the following usage: ```bash python tools/convert_checkpoint/deepspeed_to_megatron.py -h Convert DeepSpeed Checkpoint to Megatron Checkpoint usage: deepspeed_to_megatron.py [-h] [--input_folder INPUT_FOLDER] [--output_folder OUTPUT_FOLDER] + [--target_tp TARGET_TP] + [--target_pp TARGET_PP] optional arguments: -h, --help show this help message and exit @@ -19,8 +21,12 @@ optional arguments: Input DeepSpeed Checkpoint folder --output_folder OUTPUT_FOLDER Output Megatron checkpoint folder + --target_tp TARGET_TP + Target TP degree + --target_pp TARGET_PP + Target PP degree ``` The following scripts which proved useful for debugging are also included: -1. `inspect_deepspeed_checkpoint.py` to view the contents of a DeepSpeed checkpoint folder. -2. `inspect_checkpoint.py` to view the contents of a PyTorch checkpoint file. \ No newline at end of file +1. `inspect_deepspeed_checkpoint.py`: view the contents of a DeepSpeed checkpoint folder. +2. `inspect_checkpoint.py`: view the contents of a PyTorch checkpoint file. \ No newline at end of file diff --git a/tools/convert_checkpoint/deepspeed_checkpoint.py b/tools/convert_checkpoint/deepspeed_checkpoint.py index a1fd06e6b60..b68d227dbf3 100644 --- a/tools/convert_checkpoint/deepspeed_checkpoint.py +++ b/tools/convert_checkpoint/deepspeed_checkpoint.py @@ -7,8 +7,22 @@ MP_RANK_FILE_PREFIX = 'mp_rank_' EMBEDDING_LAYER_INDEX = 0 FINAL_LAYER_NORM_INDEX = -1 +ARGS_KEY = 'args' +SEQUENTIAL_LAYERS = [ + 'input_layernorm.weight', 'input_layernorm.bias', + 'self_attention.dense.bias', + 'post_attention_layernorm.weight', 'post_attention_layernorm.bias', + 'mlp.dense_4h_to_h.bias', + 'position_embeddings.weight' +] + +LAYER_CONCAT_DIM = { + 'self_attention.dense.weight': 1, + 'mlp.dense_4h_to_h.weight': 1 +} + class DeepSpeedCheckpoint(object): - def __init__(self, dir): + def __init__(self, dir, tp_degree=None, pp_degree=None): self.dir = dir self.file_list = self._get_files(dir) self.zero_files = self._get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX) @@ -16,9 +30,12 @@ def __init__(self, dir): self.mp_rank_files = self._get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX) self.layer_keys = self._get_layer_keys() self.layer_count = len(self.layer_keys) - self.tp_degree = len(self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01')) - self.pp_degree = len(self.mp_rank_files) // self.tp_degree - self.dp_degree = len(self.zero_files) // (self.pp_degree * self.tp_degree) + self.original_tp_degree = len(self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01')) + self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree + self.dp_degree = len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree) + self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree + self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree + self._sanity_check() self.pp_to_transformer_map = self._build_pp_transformer_map() self.transformer_file_map = self._build_transformer_file_map() @@ -26,9 +43,6 @@ def __init__(self, dir): self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX) - def show_layer_file_map(self): - self._dump_file_map(self.layer_file_map) - def show_tp_embedding_map(self): self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers') @@ -42,30 +56,35 @@ def show_transformer_file_map(self): self._dump_mapping(self.transformer_file_map, 
'rank_to_tranformer_files') def get_embedding_state(self, tp_index: int) -> Dict: - embedding_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[EMBEDDING_LAYER_INDEX]) - assert tp_index < len(embedding_files) - sd = torch.load(embedding_files[tp_index]) + assert tp_index in self.tp_to_embedding_map.keys() + sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]] + sd = self._merge_state_dicts(sd_list) return sd + def get_args(self): + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + return sd[ARGS_KEY] if ARGS_KEY in sd.keys() else None + def get_transformer_state(self, tp_index: int, pp_index: int) -> list: assert tp_index < self.tp_degree assert pp_index < self.pp_degree t_list = [] - for fname in self.transformer_file_map[(tp_index, pp_index)]: - sd = torch.load(fname) + for fname_list in self.transformer_file_map[(tp_index, pp_index)]: + sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list] + sd = self._merge_state_dicts(sd_list) t_list.append(sd) return t_list def get_final_norm_state(self, tp_index:int) -> Dict: - final_norm_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[FINAL_LAYER_NORM_INDEX]) - assert tp_index < len(final_norm_files) - sd = torch.load(final_norm_files[tp_index]) + assert tp_index in self.tp_to_final_norm_map.keys() + sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu')) return sd def _build_tp_other_layer_map(self, layer_index:int): assert layer_index < len(self.layer_files) layer_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[layer_index]) - data_map = {i:fname for i, fname in enumerate(layer_files)} + layer_file_partitions = self._partition_data(layer_files, self.tp_degree) + data_map = {i:flist for i, flist in enumerate(layer_file_partitions)} return data_map def _build_pp_transformer_map(self): @@ -88,12 +107,12 @@ def _build_transformer_file_map(self): for key_index, layer_key in enumerate(transformer_layer_keys): pp_index = key_index // layers_per_pp layer_files = self._get_files_with_prefix(self.layer_files, layer_key) - assert len(layer_files) == self.tp_degree - for file_index, fname in enumerate(layer_files): - map_key = (file_index, pp_index) + layer_file_partitions = self._partition_data(layer_files, self.tp_degree) + for tp_index in range(self.tp_degree): + map_key = (tp_index, pp_index) if not map_key in file_map.keys(): file_map[map_key] = [] - file_map[map_key].append(fname) + file_map[map_key].append(layer_file_partitions[tp_index]) return file_map @@ -131,4 +150,20 @@ def _get_layer_keys(self): _, fname = os.path.split(file_path) key_set.add(fname[:key_len]) return sorted(list(key_set)) - \ No newline at end of file + + def _partition_data(self, data_list, num_partitions): + num_elems = len(data_list) + assert num_elems % num_partitions == 0 + partition_size = num_elems // num_partitions + partitions_list = [data_list[i:i+partition_size] for i in range(0, num_elems, partition_size)] + return partitions_list + + def _merge_state_dicts(self, sd_list): + merged_sd = {} + for key in sd_list[0].keys(): + if not key in SEQUENTIAL_LAYERS: + cat_dim = LAYER_CONCAT_DIM.get(key, 0) + merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim) + else: + merged_sd[key] = sd_list[0][key] + return merged_sd diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py index 
21b7239857e..763593060bc 100644 --- a/tools/convert_checkpoint/deepspeed_to_megatron.py +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -2,20 +2,25 @@ import os import torch from collections import OrderedDict -from deepspeed_checkpoint import DeepSpeedCheckpoint +from deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint MODEL_KEY = 'model' +ARGS_KEY = 'args' LANGUGAGE_MODEL_KEY = 'language_model' EMBEDDING_KEY = 'embedding' ENCODER_KEY = 'encoder' WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head' WORD_EMBEDDINGS_KEY = 'word_embeddings' FINAL_LAYER_NORM_KEY ='final_layernorm' +CHECKPOINT_VERSION_KEY = 'checkpoint_version' +CHECKPOINT_VERSION_VALUE = 3.0 def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--input_folder', default=None, type=str, help='Input DeepSpeed Checkpoint folder') parser.add_argument('--output_folder', default=None, type=str, help='Output Megatron checkpoint folder') + parser.add_argument('--target_tp', default=None, type=int, help='Target TP degree') + parser.add_argument('--target_pp', default=None, type=int, help='Target PP degree') args = parser.parse_args() print(f'args = {args}') return args @@ -30,9 +35,17 @@ def _convert_ds_transformer_state(sd_list): return new_sd -def _create_checkpoint_path(tp_index, pp_index): - rank_folder = f'mp_rank_{tp_index:02d}_{pp_index:03d}' - return os.path.join(rank_folder, 'model_optim_rng.pt') +def _create_checkpoint_paths(base_folder, tp_degree, pp_degree): + path_list = [] + for i in range(0, tp_degree): + path_list.append([]) + for j in range(0, pp_degree): + rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}' + ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt') + path_list[i].append(os.path.join(base_folder, ckpt_path)) + + return path_list + def _create_megatron_dict(): language_model_dict = { @@ -41,7 +54,8 @@ def _create_megatron_dict(): } megatron_dict = { MODEL_KEY: {LANGUGAGE_MODEL_KEY: language_model_dict}, - WORD_EMBEDDINGS_FOR_HEAD_KEY: OrderedDict() + WORD_EMBEDDINGS_FOR_HEAD_KEY: OrderedDict(), + CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE } return megatron_dict @@ -51,9 +65,8 @@ def _save_checkpoint(file_path, chkpt_sd): os.makedirs(dir, exist_ok=True) torch.save(chkpt_sd, file_path) -def _create_rank_checkpoint(ds_checkpoint, output_folder, tp_index, pp_index): - checkpoint_path = os.path.join(output_folder, _create_checkpoint_path(tp_index, pp_index)) +def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index): meg_encoder_sd = OrderedDict() meg_embedding_sd = OrderedDict() meg_embedding_for_head_sd = OrderedDict() @@ -82,6 +95,7 @@ def _create_rank_checkpoint(ds_checkpoint, output_folder, tp_index, pp_index): checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd + checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args() _save_checkpoint(checkpoint_path, checkpoint_sd) @@ -92,10 +106,11 @@ def main(): args = parse_arguments() print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}') - ds_checkpoint = DeepSpeedCheckpoint(args.input_folder) + ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp) + checkpoint_paths = _create_checkpoint_paths(args.output_folder, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree) for i in range(0, 
ds_checkpoint.tp_degree): for j in range(0, ds_checkpoint.pp_degree): - _create_rank_checkpoint(ds_checkpoint, args.output_folder, i, j) + _create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j) if __name__ == "__main__": diff --git a/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py b/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py index 4172e73c59a..3125f7d9a78 100644 --- a/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py +++ b/tools/convert_checkpoint/inspect_deepspeed_checkpoint.py @@ -9,6 +9,8 @@ def list_files(file_list, tag): def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--folder', default=None, type=str, help='DeepSpeed Checkpoint folder') + parser.add_argument('--target_tp', default=None, type=int, help='Target TP degree') + parser.add_argument('--target_pp', default=None, type=int, help='Target PP degree') args = parser.parse_args() print(f'args = {args}') return args @@ -62,7 +64,7 @@ def main(): print(f'Inspecting DeepSpeed Checkpoint') args = parse_arguments() - ds_checkpoint = DeepSpeedCheckpoint(args.folder) + ds_checkpoint = DeepSpeedCheckpoint(args.folder, args.target_tp, args.target_pp) ds_checkpoint.validate_files() show_input_files(ds_checkpoint) @@ -71,6 +73,8 @@ def main(): show_embedding_states(ds_checkpoint) show_final_norm_states(ds_checkpoint) show_transformer_states(ds_checkpoint) + checkpoint_args = ds_checkpoint.get_args() + print(f'checkpoint args = {checkpoint_args}') if __name__ == "__main__": main() From 4ff8b621d838d28338aa02217abe127e3b5d2cd1 Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Thu, 9 Sep 2021 12:45:32 +0000 Subject: [PATCH 04/10] Fix typo --- tools/convert_checkpoint/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/convert_checkpoint/README.md b/tools/convert_checkpoint/README.md index 7c56bf3ca3b..5e82d566339 100644 --- a/tools/convert_checkpoint/README.md +++ b/tools/convert_checkpoint/README.md @@ -6,7 +6,7 @@ Here are the list and details of checkpoint conversions provided by the availabl ## DeepSpeed to Megatron-LM -The (current implementation of the) converter extracts args and model paramemters from a DeepSpeed checkpoint (i.e., excludes other training states such as optimizer, scheduler, etc) and convert into a Megatron-LM checkpoint similarly containing only model parameters. The converter also provides a best-effort attempt to reshape the tensor-parallelism and pipeline parallelism degrees for the checkpoint. The resulting Megatron-LM checkpoint could be loaded into Megatron-LM framework for finetuning or inference. Tensor parallelism (TP) and pipeline parallelism (PP) are supported in the sense that the generated Megatron-LM checkpoint (folders and files) will be of the same TP and PP of the training that created the input DeepSpeed checkpoint. The entry point of the converter is `deepspeed_to_megatron.py`, which as the following usage: +The (current implementation of the) converter extracts args and model parameters from a DeepSpeed checkpoint (i.e., excludes other training states such as optimizer, scheduler, etc) and convert into a Megatron-LM checkpoint similarly containing only model parameters. The converter also provides a best-effort attempt to reshape the tensor-parallelism and pipeline parallelism degrees for the checkpoint. The resulting Megatron-LM checkpoint could be loaded into Megatron-LM framework for finetuning or inference. 
Tensor parallelism (TP) and pipeline parallelism (PP) are supported in the sense that the generated Megatron-LM checkpoint (folders and files) will be of the same TP and PP of the training that created the input DeepSpeed checkpoint. The entry point of the converter is `deepspeed_to_megatron.py`, which as the following usage: ```bash python tools/convert_checkpoint/deepspeed_to_megatron.py -h Convert DeepSpeed Checkpoint to Megatron Checkpoint From 03ddc9c3a2b9f2b15a2353706c7338c3ee6f8011 Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Thu, 9 Sep 2021 12:57:52 +0000 Subject: [PATCH 05/10] Fix link --- tools/convert_checkpoint/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/convert_checkpoint/README.md b/tools/convert_checkpoint/README.md index 5e82d566339..8f527179fe9 100644 --- a/tools/convert_checkpoint/README.md +++ b/tools/convert_checkpoint/README.md @@ -2,10 +2,10 @@ This folder is a collection of scripts for converting checkpoints of one training framework (e.g., DeepSpeed) into that of a different framework (e.g., Megatron-LM). inspecting checkpoints. The folder also contains scripts for inspecting checkpoint files and folders, which could be useful when developing checkpoint conversion logic. At the time of creation, this folder contains scripts to convert DeepSpeed checkpoints to Megatron-LM checkpoints (this motivated this effort as part of the BigScience project). Here are the list and details of checkpoint conversions provided by the available scripts. -1. [DeepSpeed to Megatron-LM](#DeepSpeed to Megatron-LM) +1. [DeepSpeed to Megatron-LM](#DeepSpeed-to-Megatron) -## DeepSpeed to Megatron-LM +## DeepSpeed to Megatron The (current implementation of the) converter extracts args and model parameters from a DeepSpeed checkpoint (i.e., excludes other training states such as optimizer, scheduler, etc) and convert into a Megatron-LM checkpoint similarly containing only model parameters. The converter also provides a best-effort attempt to reshape the tensor-parallelism and pipeline parallelism degrees for the checkpoint. The resulting Megatron-LM checkpoint could be loaded into Megatron-LM framework for finetuning or inference. Tensor parallelism (TP) and pipeline parallelism (PP) are supported in the sense that the generated Megatron-LM checkpoint (folders and files) will be of the same TP and PP of the training that created the input DeepSpeed checkpoint. 
The entry point of the converter is `deepspeed_to_megatron.py`, which as the following usage: ```bash python tools/convert_checkpoint/deepspeed_to_megatron.py -h From 2758f21ec6155efc02ee432905d109d51f9a440f Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Thu, 9 Sep 2021 12:59:57 +0000 Subject: [PATCH 06/10] Tweak tag --- tools/convert_checkpoint/inspect_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/convert_checkpoint/inspect_checkpoint.py b/tools/convert_checkpoint/inspect_checkpoint.py index 9b092020aad..5ee955bb480 100644 --- a/tools/convert_checkpoint/inspect_checkpoint.py +++ b/tools/convert_checkpoint/inspect_checkpoint.py @@ -13,11 +13,11 @@ def dump_data(datum, name_list=[]): dump_data(v, name_list) elif torch.is_tensor(datum): prefix = '.'.join(name_list) - print(f'tensor {prefix} = {datum.shape}') + print(f'[tensor] {prefix} = {datum.shape}') else: #pass prefix = '.'.join(name_list) - print(f'other {prefix} = {datum}') + print(f'[other] {prefix} = {datum}') def main(): if len(sys.argv) < 2: From 872d63cca513deea30e1a6dffa68734ce026283f Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Fri, 10 Sep 2021 21:12:37 +0000 Subject: [PATCH 07/10] Fix converted TP and PP sizes --- tools/convert_checkpoint/deepspeed_to_megatron.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py index 763593060bc..3e0f737546f 100644 --- a/tools/convert_checkpoint/deepspeed_to_megatron.py +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -95,7 +95,11 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index): checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd + checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args() + # Adjust specific fields + checkpoint_sd[ARGS_KEY].tensor_model_parallel_size = ds_checkpoint.tp_degree + checkpoint_sd[ARGS_KEY].pipeline_model_parallel_size = ds_checkpoint.pp_degree _save_checkpoint(checkpoint_path, checkpoint_sd) From f6d2fa4d19c0eb8b2d7b93ac123828a74d661b4d Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Fri, 10 Sep 2021 21:20:11 +0000 Subject: [PATCH 08/10] For release mode --- tools/convert_checkpoint/deepspeed_to_megatron.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py index 3e0f737546f..6131d149692 100644 --- a/tools/convert_checkpoint/deepspeed_to_megatron.py +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -21,6 +21,7 @@ def parse_arguments(): parser.add_argument('--output_folder', default=None, type=str, help='Output Megatron checkpoint folder') parser.add_argument('--target_tp', default=None, type=int, help='Target TP degree') parser.add_argument('--target_pp', default=None, type=int, help='Target PP degree') + parser.add_argument('--for_release', action='store_true', help='Convert for release purpose, reset some (progress) counters.') args = parser.parse_args() print(f'args = {args}') return args @@ -66,7 +67,7 @@ def _save_checkpoint(file_path, chkpt_sd): torch.save(chkpt_sd, file_path) -def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index): +def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, 
for_release=False): meg_encoder_sd = OrderedDict() meg_embedding_sd = OrderedDict() meg_embedding_for_head_sd = OrderedDict() @@ -100,6 +101,9 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index): # Adjust specific fields checkpoint_sd[ARGS_KEY].tensor_model_parallel_size = ds_checkpoint.tp_degree checkpoint_sd[ARGS_KEY].pipeline_model_parallel_size = ds_checkpoint.pp_degree + if for_release: + checkpoint_sd[ARGS_KEY].consumed_train_samples = 0 + checkpoint_sd[ARGS_KEY].consumed_valid_samples = 0 _save_checkpoint(checkpoint_path, checkpoint_sd) @@ -114,7 +118,7 @@ def main(): checkpoint_paths = _create_checkpoint_paths(args.output_folder, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree) for i in range(0, ds_checkpoint.tp_degree): for j in range(0, ds_checkpoint.pp_degree): - _create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j) + _create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j, args.for_release) if __name__ == "__main__": From ca0ce736ae82e9ea68944d2bc79d649ff912f11e Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Fri, 10 Sep 2021 22:27:34 +0000 Subject: [PATCH 09/10] Update README --- tools/convert_checkpoint/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/convert_checkpoint/README.md b/tools/convert_checkpoint/README.md index 8f527179fe9..29cfaa805b0 100644 --- a/tools/convert_checkpoint/README.md +++ b/tools/convert_checkpoint/README.md @@ -13,7 +13,7 @@ Convert DeepSpeed Checkpoint to Megatron Checkpoint usage: deepspeed_to_megatron.py [-h] [--input_folder INPUT_FOLDER] [--output_folder OUTPUT_FOLDER] [--target_tp TARGET_TP] - [--target_pp TARGET_PP] + [--target_pp TARGET_PP] [--for_release] optional arguments: -h, --help show this help message and exit @@ -25,6 +25,8 @@ optional arguments: Target TP degree --target_pp TARGET_PP Target PP degree + --for_release Convert for release purpose, reset some (progress) + counters. 
``` The following scripts which proved useful for debugging are also included: From 32f2a7fe297f405e1174d17068eb2ea27e95cb30 Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Fri, 17 Sep 2021 21:56:35 +0000 Subject: [PATCH 10/10] Nested embedding dicts Iteration folder latest checkpoint version file --- .../deepspeed_checkpoint.py | 26 +++++++++++-- .../deepspeed_to_megatron.py | 39 +++++++++++++++---- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/tools/convert_checkpoint/deepspeed_checkpoint.py b/tools/convert_checkpoint/deepspeed_checkpoint.py index b68d227dbf3..c38e0d5505c 100644 --- a/tools/convert_checkpoint/deepspeed_checkpoint.py +++ b/tools/convert_checkpoint/deepspeed_checkpoint.py @@ -8,6 +8,7 @@ EMBEDDING_LAYER_INDEX = 0 FINAL_LAYER_NORM_INDEX = -1 ARGS_KEY = 'args' +ITERATION_KEY = 'iteration' SEQUENTIAL_LAYERS = [ 'input_layernorm.weight', 'input_layernorm.bias', 'self_attention.dense.bias', @@ -35,12 +36,15 @@ def __init__(self, dir, tp_degree=None, pp_degree=None): self.dp_degree = len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree) self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree - + self.global_state = {} + self._sanity_check() self.pp_to_transformer_map = self._build_pp_transformer_map() self.transformer_file_map = self._build_transformer_file_map() self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX) self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX) + self._build_global_state() + def show_tp_embedding_map(self): @@ -55,6 +59,18 @@ def show_pp_tranformer_map(self): def show_transformer_file_map(self): self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files') + def _build_global_state(self): + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) + self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) + + def get_iteration(self): + if not ITERATION_KEY in self.global_state: + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0) + + return self.global_state[ITERATION_KEY] + def get_embedding_state(self, tp_index: int) -> Dict: assert tp_index in self.tp_to_embedding_map.keys() sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]] @@ -62,8 +78,12 @@ def get_embedding_state(self, tp_index: int) -> Dict: return sd def get_args(self): - sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) - return sd[ARGS_KEY] if ARGS_KEY in sd.keys() else None + if not ARGS_KEY in self.global_state: + sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu')) + self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None) + + return self.global_state[ARGS_KEY] + def get_transformer_state(self, tp_index: int, pp_index: int) -> list: assert tp_index < self.tp_degree diff --git a/tools/convert_checkpoint/deepspeed_to_megatron.py b/tools/convert_checkpoint/deepspeed_to_megatron.py index 6131d149692..d67d1840872 100644 --- a/tools/convert_checkpoint/deepspeed_to_megatron.py +++ b/tools/convert_checkpoint/deepspeed_to_megatron.py @@ -14,6 +14,7 @@ FINAL_LAYER_NORM_KEY ='final_layernorm' CHECKPOINT_VERSION_KEY = 'checkpoint_version' CHECKPOINT_VERSION_VALUE = 3.0 +ITERATION_KEY = 'iteration' def parse_arguments(): parser = 
argparse.ArgumentParser() @@ -36,14 +37,15 @@ def _convert_ds_transformer_state(sd_list): return new_sd -def _create_checkpoint_paths(base_folder, tp_degree, pp_degree): +def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree): path_list = [] + iter_folder = f'iter_{iteration:07d}' for i in range(0, tp_degree): path_list.append([]) for j in range(0, pp_degree): rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}' ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt') - path_list[i].append(os.path.join(base_folder, ckpt_path)) + path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path)) return path_list @@ -55,7 +57,6 @@ def _create_megatron_dict(): } megatron_dict = { MODEL_KEY: {LANGUGAGE_MODEL_KEY: language_model_dict}, - WORD_EMBEDDINGS_FOR_HEAD_KEY: OrderedDict(), CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE } return megatron_dict @@ -67,9 +68,17 @@ def _save_checkpoint(file_path, chkpt_sd): torch.save(chkpt_sd, file_path) +def _renest_sd(sd): + new_sd = OrderedDict() + for key, value in sd.items(): + a, b = key.split('.') + new_sd[a] = {b: value} + return new_sd + + def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, for_release=False): meg_encoder_sd = OrderedDict() - meg_embedding_sd = OrderedDict() + meg_embedding_sd = OrderedDict() meg_embedding_for_head_sd = OrderedDict() transformer_sd = ds_checkpoint.get_transformer_state(tp_index, pp_index) @@ -77,8 +86,9 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, if pp_index in [0, ds_checkpoint.pp_degree - 1]: embedding_sd = ds_checkpoint.get_embedding_state(tp_index) + nested_embedding_sd = _renest_sd(embedding_sd) if pp_index == 0: - meg_embedding_sd.update(embedding_sd) + meg_embedding_sd.update(nested_embedding_sd) if pp_index == ds_checkpoint.pp_degree -1: for key, value in embedding_sd.items(): @@ -93,9 +103,14 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, meg_encoder_sd.update(new_final_norm_sd) checkpoint_sd = _create_megatron_dict() - checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd + + iteration = ds_checkpoint.get_iteration() + checkpoint_sd[ITERATION_KEY] = iteration + if pp_index == 0: + checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd checkpoint_sd[MODEL_KEY][LANGUGAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd - checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd + if pp_index == ds_checkpoint.pp_degree -1: + checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args() # Adjust specific fields @@ -108,6 +123,12 @@ def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, _save_checkpoint(checkpoint_path, checkpoint_sd) +def _create_latest_file(base_folder, iteration): + file_path = os.path.join(base_folder, 'latest_checkpointed_iteration.txt') + os.makedirs(base_folder, exist_ok=True) + with open(file_path, 'w') as f: + f.write(str(iteration)) + def main(): print(f'Convert DeepSpeed Checkpoint to Megatron Checkpoint') @@ -115,7 +136,9 @@ def main(): print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}') ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp) - checkpoint_paths = _create_checkpoint_paths(args.output_folder, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree) + iteration = 
ds_checkpoint.get_iteration() + _create_latest_file(args.output_folder, iteration) + checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree) for i in range(0, ds_checkpoint.tp_degree): for j in range(0, ds_checkpoint.pp_degree): _create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j, args.for_release)
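
For readers following the series, the core of the TP reshaping introduced in patch 03 is the shard-merge rule in `_merge_state_dicts`: replicated parameters are copied from a single rank, while partitioned parameters are concatenated along dim 0 (column-parallel default) or dim 1 (row-parallel weights listed in `LAYER_CONCAT_DIM`). Below is a minimal, standalone sketch of that rule — it is not part of the patches above, the tensor shapes are toy values, and only the layer names and concat dimensions are taken from `deepspeed_checkpoint.py`.

```python
# Sketch of the TP merge rule mirrored from _merge_state_dicts (patch 03).
# SEQUENTIAL_LAYERS / LAYER_CONCAT_DIM follow deepspeed_checkpoint.py; shapes are toy values.
import torch

SEQUENTIAL_LAYERS = {'input_layernorm.weight', 'input_layernorm.bias'}   # replicated across TP ranks
LAYER_CONCAT_DIM = {'self_attention.dense.weight': 1,                    # row-parallel: concat on dim 1
                    'mlp.dense_4h_to_h.weight': 1}

def merge_tp_shards(shard_sds):
    """Merge per-TP-rank state dicts of one layer into a single state dict."""
    merged = {}
    for key in shard_sds[0]:
        if key in SEQUENTIAL_LAYERS:
            merged[key] = shard_sds[0][key]              # identical on every rank, keep one copy
        else:
            dim = LAYER_CONCAT_DIM.get(key, 0)           # column-parallel default: concat on dim 0
            merged[key] = torch.cat([sd[key] for sd in shard_sds], dim=dim)
    return merged

# Two TP shards of a toy transformer layer: the dense weight is split along its input dimension.
shards = [
    {'self_attention.dense.weight': torch.randn(8, 4), 'input_layernorm.weight': torch.ones(8)},
    {'self_attention.dense.weight': torch.randn(8, 4), 'input_layernorm.weight': torch.ones(8)},
]
merged = merge_tp_shards(shards)
print(merged['self_attention.dense.weight'].shape)  # torch.Size([8, 8]) -- halves rejoined on dim 1
print(merged['input_layernorm.weight'].shape)       # torch.Size([8])    -- unchanged
```

With patch 10 applied, the converter writes each rank's merged state to `<output_folder>/iter_<iteration>/mp_rank_<tp>[_<pp>]/model_optim_rng.pt` and records the iteration in `latest_checkpointed_iteration.txt` at the output root, which is the folder layout Megatron-LM's checkpoint loader expects to find.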