forked from NVIDIA/Megatron-LM
Checkpoint conversion tools #14
Merged
Commits (10):
- 3afaf2d  Checkpoint conversion tools (tjruwase)
- 6f6b342  Fix formatting (tjruwase)
- 0fa9543  1) Provide args in converted checkpoint (tjruwase)
- 4ff8b62  Fix typo (tjruwase)
- 03ddc9c  Fix link (tjruwase)
- 2758f21  Tweak tag (tjruwase)
- 872d63c  Fix converted TP and PP sizes (tjruwase)
- f6d2fa4  For release mode (tjruwase)
- ca0ce73  Update README (tjruwase)
- 32f2a7f  Nested embedding dicts (tjruwase)
New file: tools/convert_checkpoint/README.md

# Introduction

This folder contains scripts for converting checkpoints of one training framework (e.g., DeepSpeed) into checkpoints of a different framework (e.g., Megatron-LM). It also contains scripts for inspecting checkpoint files and folders, which can be useful when developing checkpoint conversion logic. At the time of creation, this folder contains scripts to convert DeepSpeed checkpoints to Megatron-LM checkpoints; this use case motivated the effort as part of the BigScience project.

Below is the list of checkpoint conversions provided by the available scripts.

1. [DeepSpeed to Megatron-LM](#DeepSpeed-to-Megatron)

## DeepSpeed to Megatron

The current implementation of the converter extracts the args and model parameters from a DeepSpeed checkpoint (i.e., it excludes other training state such as the optimizer and scheduler) and converts them into a Megatron-LM checkpoint that likewise contains only model parameters. The converter also makes a best-effort attempt to reshape the tensor-parallelism (TP) and pipeline-parallelism (PP) degrees of the checkpoint; when no target degrees are given, the generated Megatron-LM checkpoint (folders and files) keeps the same TP and PP degrees as the training run that created the input DeepSpeed checkpoint. The resulting Megatron-LM checkpoint can be loaded into the Megatron-LM framework for finetuning or inference. The entry point of the converter is `deepspeed_to_megatron.py`, which has the following usage:

```bash
python tools/convert_checkpoint/deepspeed_to_megatron.py -h
Convert DeepSpeed Checkpoint to Megatron Checkpoint
usage: deepspeed_to_megatron.py [-h] [--input_folder INPUT_FOLDER]
                                [--output_folder OUTPUT_FOLDER]
                                [--target_tp TARGET_TP]
                                [--target_pp TARGET_PP] [--for_release]

optional arguments:
  -h, --help            show this help message and exit
  --input_folder INPUT_FOLDER
                        Input DeepSpeed Checkpoint folder
  --output_folder OUTPUT_FOLDER
                        Output Megatron checkpoint folder
  --target_tp TARGET_TP
                        Target TP degree
  --target_pp TARGET_PP
                        Target PP degree
  --for_release         Convert for release purpose, reset some (progress)
                        counters.
```

The following scripts, which proved useful for debugging, are also included:

1. `inspect_deepspeed_checkpoint.py`: view the contents of a DeepSpeed checkpoint folder.
2. `inspect_checkpoint.py`: view the contents of a PyTorch checkpoint file.
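
As a quick sanity check after running the converter, a converted rank file can be opened directly with `torch.load` and its top-level keys inspected. A minimal sketch, assuming a hypothetical output folder `megatron_ckpt` converted with TP=1 and PP=1 (so the single rank folder is `mp_rank_00`):

```python
import os
import torch

base = 'megatron_ckpt'  # hypothetical --output_folder used during conversion

# The converter records the converted iteration in a small text file.
with open(os.path.join(base, 'latest_checkpointed_iteration.txt')) as f:
    iteration = int(f.read())

# With PP=1 the rank folder is mp_rank_00 (mp_rank_XX_YYY when PP > 1).
ckpt_path = os.path.join(base, f'iter_{iteration:07d}', 'mp_rank_00', 'model_optim_rng.pt')
sd = torch.load(ckpt_path, map_location='cpu')

print(sd['checkpoint_version'])              # 3.0, set by the converter
print(sd['iteration'])                       # iteration carried over from the DeepSpeed checkpoint
print(sd['model']['language_model'].keys())  # 'embedding' and 'encoder' sub-dicts
print(sd['args'].tensor_model_parallel_size) # adjusted to the converted TP degree
```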
New file: tools/convert_checkpoint/deepspeed_checkpoint.py
```python
import os
from typing import Dict
import torch

ZERO_FILE_PREFIX = 'zero_pp_rank_'
LAYER_FILE_PREFIX = 'layer_'
MP_RANK_FILE_PREFIX = 'mp_rank_'
EMBEDDING_LAYER_INDEX = 0
FINAL_LAYER_NORM_INDEX = -1
ARGS_KEY = 'args'
ITERATION_KEY = 'iteration'

# Parameters that are replicated across tensor-parallel ranks and therefore
# taken from a single rank instead of being concatenated.
SEQUENTIAL_LAYERS = [
    'input_layernorm.weight', 'input_layernorm.bias',
    'self_attention.dense.bias',
    'post_attention_layernorm.weight', 'post_attention_layernorm.bias',
    'mlp.dense_4h_to_h.bias',
    'position_embeddings.weight'
]

# Concatenation dimension for tensor-parallel partitioned parameters (default is 0).
LAYER_CONCAT_DIM = {
    'self_attention.dense.weight': 1,
    'mlp.dense_4h_to_h.weight': 1
}


class DeepSpeedCheckpoint(object):
    def __init__(self, dir, tp_degree=None, pp_degree=None):
        self.dir = dir
        self.file_list = self._get_files(dir)
        self.zero_files = self._get_files_with_prefix(self.file_list, ZERO_FILE_PREFIX)
        self.layer_files = self._get_files_with_prefix(self.file_list, LAYER_FILE_PREFIX)
        self.mp_rank_files = self._get_files_with_prefix(self.file_list, MP_RANK_FILE_PREFIX)
        self.layer_keys = self._get_layer_keys()
        self.layer_count = len(self.layer_keys)
        # Infer the parallelism degrees of the training run from the checkpoint layout.
        self.original_tp_degree = len(self._get_files_with_prefix(self.layer_files, f'{LAYER_FILE_PREFIX}01'))
        self.original_pp_degree = len(self.mp_rank_files) // self.original_tp_degree
        self.dp_degree = len(self.zero_files) // (self.original_pp_degree * self.original_tp_degree)
        # Target degrees default to the original ones when not specified.
        self.tp_degree = self.original_tp_degree if tp_degree is None else tp_degree
        self.pp_degree = self.original_pp_degree if pp_degree is None else pp_degree
        self.global_state = {}

        self._sanity_check()
        self.pp_to_transformer_map = self._build_pp_transformer_map()
        self.transformer_file_map = self._build_transformer_file_map()
        self.tp_to_embedding_map = self._build_tp_other_layer_map(EMBEDDING_LAYER_INDEX)
        self.tp_to_final_norm_map = self._build_tp_other_layer_map(FINAL_LAYER_NORM_INDEX)
        self._build_global_state()

    def show_tp_embedding_map(self):
        self._dump_mapping(self.tp_to_embedding_map, 'tp_to_embedding_layers')

    def show_tp_final_norm_map(self):
        self._dump_mapping(self.tp_to_final_norm_map, 'tp_to_final_norm_layers')

    def show_pp_tranformer_map(self):
        self._dump_mapping(self.pp_to_transformer_map, 'pp_to_tranformer_layers')

    def show_transformer_file_map(self):
        self._dump_mapping(self.transformer_file_map, 'rank_to_tranformer_files')

    def _build_global_state(self):
        sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
        self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)
        self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

    def get_iteration(self):
        if ITERATION_KEY not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ITERATION_KEY] = sd.get(ITERATION_KEY, 0)

        return self.global_state[ITERATION_KEY]

    def get_embedding_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_embedding_map
        sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in self.tp_to_embedding_map[tp_index]]
        sd = self._merge_state_dicts(sd_list)
        return sd

    def get_args(self):
        if ARGS_KEY not in self.global_state:
            sd = torch.load(self.mp_rank_files[0], map_location=torch.device('cpu'))
            self.global_state[ARGS_KEY] = sd.get(ARGS_KEY, None)

        return self.global_state[ARGS_KEY]

    def get_transformer_state(self, tp_index: int, pp_index: int) -> list:
        assert tp_index < self.tp_degree
        assert pp_index < self.pp_degree
        t_list = []
        for fname_list in self.transformer_file_map[(tp_index, pp_index)]:
            sd_list = [torch.load(fname, map_location=torch.device('cpu')) for fname in fname_list]
            sd = self._merge_state_dicts(sd_list)
            t_list.append(sd)
        return t_list

    def get_final_norm_state(self, tp_index: int) -> Dict:
        assert tp_index in self.tp_to_final_norm_map
        sd = torch.load(self.tp_to_final_norm_map[tp_index][0], map_location=torch.device('cpu'))
        return sd

    def _build_tp_other_layer_map(self, layer_index: int):
        assert layer_index < len(self.layer_files)
        layer_files = self._get_files_with_prefix(self.layer_files, self.layer_keys[layer_index])
        layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
        data_map = {i: flist for i, flist in enumerate(layer_file_partitions)}
        return data_map

    def _build_pp_transformer_map(self):
        # The first and last layer keys are the embedding and final layernorm;
        # the transformer layers in between are split evenly across PP stages.
        transformer_layers = self.layer_keys[1:-1]
        layers_per_pp = len(transformer_layers) // self.pp_degree
        data_map = {i: transformer_layers[i * layers_per_pp:(i + 1) * layers_per_pp] for i in range(self.pp_degree)}
        return data_map

    def _dump_mapping(self, data_map, map_tag=None):
        if map_tag is not None:
            print(f'Dump mapping: {map_tag}')
        for k, v in data_map.items():
            print(f'{k} = {v}')

    def _build_transformer_file_map(self):
        transformer_layer_keys = self.layer_keys[1:-1]
        file_map = {}
        layers_per_pp = len(transformer_layer_keys) // self.pp_degree
        for key_index, layer_key in enumerate(transformer_layer_keys):
            pp_index = key_index // layers_per_pp
            layer_files = self._get_files_with_prefix(self.layer_files, layer_key)
            layer_file_partitions = self._partition_data(layer_files, self.tp_degree)
            for tp_index in range(self.tp_degree):
                map_key = (tp_index, pp_index)
                if map_key not in file_map:
                    file_map[map_key] = []
                file_map[map_key].append(layer_file_partitions[tp_index])

        return file_map

    def _sanity_check(self):
        assert len(self.mp_rank_files) % self.tp_degree == 0
        assert len(self.zero_files) % (self.pp_degree * self.tp_degree) == 0
        # At least one transformer layer besides the embedding and final layernorm.
        assert len(self.layer_keys) > 2
        # Transformer layers must divide evenly across the pipeline stages.
        assert (len(self.layer_keys) - 2) % self.pp_degree == 0

    def _get_files_with_prefix(self, all_files, prefix):
        file_list = []
        for file_path in all_files:
            _, fname = os.path.split(file_path)
            if fname.startswith(prefix):
                file_list.append(file_path)

        return sorted(file_list)

    def validate_files(self):
        for file in self.file_list:
            if not os.path.isfile(file):
                print(f'Error: {file} does not exist')

    def _get_files(self, dir):
        file_list = []
        for root, dirs, files in os.walk(dir):
            for file in files:
                file_list.append(os.path.join(root, file))
        return file_list

    def _get_layer_keys(self):
        # Layer keys are the 'layer_XX' prefixes of the layer checkpoint files.
        key_set = set()
        key_len = len(LAYER_FILE_PREFIX) + 2
        for file_path in self.layer_files:
            _, fname = os.path.split(file_path)
            key_set.add(fname[:key_len])
        return sorted(list(key_set))

    def _partition_data(self, data_list, num_partitions):
        num_elems = len(data_list)
        assert num_elems % num_partitions == 0
        partition_size = num_elems // num_partitions
        partitions_list = [data_list[i:i + partition_size] for i in range(0, num_elems, partition_size)]
        return partitions_list

    def _merge_state_dicts(self, sd_list):
        merged_sd = {}
        for key in sd_list[0].keys():
            if key not in SEQUENTIAL_LAYERS:
                # Partitioned parameters are concatenated along their TP split dimension.
                cat_dim = LAYER_CONCAT_DIM.get(key, 0)
                merged_sd[key] = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
            else:
                # Replicated parameters are identical across ranks; take the first copy.
                merged_sd[key] = sd_list[0][key]
        return merged_sd
```
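
`DeepSpeedCheckpoint` is essentially a read-only index over a DeepSpeed checkpoint folder, so it can also be used on its own when debugging a conversion. A minimal sketch; the input folder path below is hypothetical:

```python
from deepspeed_checkpoint import DeepSpeedCheckpoint

# Hypothetical DeepSpeed checkpoint folder produced by a 3D-parallel training run.
ds_checkpoint = DeepSpeedCheckpoint('checkpoints/global_step1000')

# Parallelism degrees inferred from the checkpoint layout, and the saved iteration.
print(ds_checkpoint.original_tp_degree, ds_checkpoint.original_pp_degree, ds_checkpoint.dp_degree)
print(ds_checkpoint.get_iteration())

# Show which layer_* files feed each (tp, pp) rank of the converted checkpoint.
ds_checkpoint.show_transformer_file_map()

# Merged-across-TP state dict of the embedding layer for TP rank 0.
embedding_sd = ds_checkpoint.get_embedding_state(0)
print(list(embedding_sd.keys()))
```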
New file: tools/convert_checkpoint/deepspeed_to_megatron.py
```python
import argparse
import os
from collections import OrderedDict

import torch

from deepspeed_checkpoint import ARGS_KEY, DeepSpeedCheckpoint

MODEL_KEY = 'model'
LANGUAGE_MODEL_KEY = 'language_model'
EMBEDDING_KEY = 'embedding'
ENCODER_KEY = 'encoder'
WORD_EMBEDDINGS_FOR_HEAD_KEY = 'word_embeddings_for_head'
WORD_EMBEDDINGS_KEY = 'word_embeddings'
FINAL_LAYER_NORM_KEY = 'final_layernorm'
CHECKPOINT_VERSION_KEY = 'checkpoint_version'
CHECKPOINT_VERSION_VALUE = 3.0
ITERATION_KEY = 'iteration'


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_folder', default=None, type=str, help='Input DeepSpeed Checkpoint folder')
    parser.add_argument('--output_folder', default=None, type=str, help='Output Megatron checkpoint folder')
    parser.add_argument('--target_tp', default=None, type=int, help='Target TP degree')
    parser.add_argument('--target_pp', default=None, type=int, help='Target PP degree')
    parser.add_argument('--for_release', action='store_true', help='Convert for release purpose, reset some (progress) counters.')
    args = parser.parse_args()
    print(f'args = {args}')
    return args


def _convert_ds_transformer_state(sd_list):
    # Flatten the per-layer state dicts into Megatron's 'layers.<i>.<param>' naming.
    new_sd = OrderedDict()
    for i, sd in enumerate(sd_list):
        for key, value in sd.items():
            new_key = f'layers.{i}.{key}'
            new_sd[new_key] = value

    return new_sd


def _create_checkpoint_paths(base_folder, iteration, tp_degree, pp_degree):
    # Megatron layout: <base>/iter_<iteration>/mp_rank_<tp>[_<pp>]/model_optim_rng.pt
    path_list = []
    iter_folder = f'iter_{iteration:07d}'
    for i in range(tp_degree):
        path_list.append([])
        for j in range(pp_degree):
            rank_folder = f'mp_rank_{i:02d}' if pp_degree == 1 else f'mp_rank_{i:02d}_{j:03d}'
            ckpt_path = os.path.join(rank_folder, 'model_optim_rng.pt')
            path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path))

    return path_list


def _create_megatron_dict():
    language_model_dict = {
        EMBEDDING_KEY: {},
        ENCODER_KEY: {}
    }
    megatron_dict = {
        MODEL_KEY: {LANGUAGE_MODEL_KEY: language_model_dict},
        CHECKPOINT_VERSION_KEY: CHECKPOINT_VERSION_VALUE
    }
    return megatron_dict


def _save_checkpoint(file_path, chkpt_sd):
    dirname, _ = os.path.split(file_path)
    os.makedirs(dirname, exist_ok=True)
    torch.save(chkpt_sd, file_path)


def _renest_sd(sd):
    # Convert flat keys such as 'word_embeddings.weight' into nested dicts:
    # {'word_embeddings': {'weight': ...}}.
    new_sd = OrderedDict()
    for key, value in sd.items():
        a, b = key.split('.')
        new_sd.setdefault(a, {})[b] = value
    return new_sd


def _create_rank_checkpoint(ds_checkpoint, checkpoint_path, tp_index, pp_index, for_release=False):
    meg_encoder_sd = OrderedDict()
    meg_embedding_sd = OrderedDict()
    meg_embedding_for_head_sd = OrderedDict()

    transformer_sd = ds_checkpoint.get_transformer_state(tp_index, pp_index)
    meg_encoder_sd.update(_convert_ds_transformer_state(transformer_sd))

    # Embeddings are only needed on the first and last pipeline stages.
    if pp_index in [0, ds_checkpoint.pp_degree - 1]:
        embedding_sd = ds_checkpoint.get_embedding_state(tp_index)
        nested_embedding_sd = _renest_sd(embedding_sd)
        if pp_index == 0:
            meg_embedding_sd.update(nested_embedding_sd)

        # The last stage additionally needs the word embeddings for the output head
        # and the final layernorm.
        if pp_index == ds_checkpoint.pp_degree - 1:
            for key, value in embedding_sd.items():
                if key.startswith(WORD_EMBEDDINGS_KEY):
                    fields = key.split('.')
                    new_fields = fields[1:]
                    new_key = '.'.join(new_fields)
                    meg_embedding_for_head_sd[new_key] = value

            final_norm_sd = ds_checkpoint.get_final_norm_state(tp_index)
            new_final_norm_sd = {f'{FINAL_LAYER_NORM_KEY}.{key}': value for key, value in final_norm_sd.items()}
            meg_encoder_sd.update(new_final_norm_sd)

    checkpoint_sd = _create_megatron_dict()

    iteration = ds_checkpoint.get_iteration()
    checkpoint_sd[ITERATION_KEY] = iteration
    if pp_index == 0:
        checkpoint_sd[MODEL_KEY][LANGUAGE_MODEL_KEY][EMBEDDING_KEY] = meg_embedding_sd
    checkpoint_sd[MODEL_KEY][LANGUAGE_MODEL_KEY][ENCODER_KEY] = meg_encoder_sd
    if pp_index == ds_checkpoint.pp_degree - 1:
        checkpoint_sd[MODEL_KEY][WORD_EMBEDDINGS_FOR_HEAD_KEY] = meg_embedding_for_head_sd

    checkpoint_sd[ARGS_KEY] = ds_checkpoint.get_args()
    # Adjust specific fields
    checkpoint_sd[ARGS_KEY].tensor_model_parallel_size = ds_checkpoint.tp_degree
    checkpoint_sd[ARGS_KEY].pipeline_model_parallel_size = ds_checkpoint.pp_degree
    if for_release:
        checkpoint_sd[ARGS_KEY].consumed_train_samples = 0
        checkpoint_sd[ARGS_KEY].consumed_valid_samples = 0

    _save_checkpoint(checkpoint_path, checkpoint_sd)


def _create_latest_file(base_folder, iteration):
    file_path = os.path.join(base_folder, 'latest_checkpointed_iteration.txt')
    os.makedirs(base_folder, exist_ok=True)
    with open(file_path, 'w') as f:
        f.write(str(iteration))


def main():
    print('Convert DeepSpeed Checkpoint to Megatron Checkpoint')

    args = parse_arguments()
    print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Megatron checkpoint in {args.output_folder}')

    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder, args.target_tp, args.target_pp)
    iteration = ds_checkpoint.get_iteration()
    _create_latest_file(args.output_folder, iteration)
    checkpoint_paths = _create_checkpoint_paths(args.output_folder, iteration, ds_checkpoint.tp_degree, ds_checkpoint.pp_degree)
    # Write one Megatron rank file per (tp, pp) pair.
    for i in range(ds_checkpoint.tp_degree):
        for j in range(ds_checkpoint.pp_degree):
            _create_rank_checkpoint(ds_checkpoint, checkpoint_paths[i][j], i, j, args.for_release)


if __name__ == "__main__":
    main()
```
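
To make the output layout concrete, here is a small sketch that calls the `_create_checkpoint_paths` helper above with made-up arguments and prints the rank files the converter would write for a TP=2, PP=2 target:

```python
from deepspeed_to_megatron import _create_checkpoint_paths

# Hypothetical output folder and iteration; TP=2 and PP=2 target degrees.
paths = _create_checkpoint_paths('megatron_ckpt', iteration=1000, tp_degree=2, pp_degree=2)
for tp_paths in paths:
    for p in tp_paths:
        print(p)

# Expected output (PP > 1, so the rank folder carries both TP and PP indices):
# megatron_ckpt/iter_0001000/mp_rank_00_000/model_optim_rng.pt
# megatron_ckpt/iter_0001000/mp_rank_00_001/model_optim_rng.pt
# megatron_ckpt/iter_0001000/mp_rank_01_000/model_optim_rng.pt
# megatron_ckpt/iter_0001000/mp_rank_01_001/model_optim_rng.pt
```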