From e5218b374c7f5fd1b01a12c8db59fa88bde913fd Mon Sep 17 00:00:00 2001
From: Somshubra Majumdar
Date: Fri, 28 Apr 2023 16:30:42 -0700
Subject: [PATCH] Add interleaved pp support (#6498)

* Add support for Virtual Pipeline Parallel conversion

Signed-off-by: smajumdar

* Add support for Virtual Pipeline Parallel conversion

Signed-off-by: smajumdar

* Switch to megatron core

Signed-off-by: smajumdar

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: smajumdar
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../megatron_change_num_partitions.py | 385 ++++++++++++++----
 1 file changed, 313 insertions(+), 72 deletions(-)

diff --git a/examples/nlp/language_modeling/megatron_change_num_partitions.py b/examples/nlp/language_modeling/megatron_change_num_partitions.py
index 944565d8bd43..a4b28fa4d761 100644
--- a/examples/nlp/language_modeling/megatron_change_num_partitions.py
+++ b/examples/nlp/language_modeling/megatron_change_num_partitions.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 
 import os
+import tempfile
 from argparse import ArgumentParser
 from typing import Dict, List
 
 import torch
-from omegaconf import open_dict
+import torch.nn as nn
+from omegaconf import OmegaConf, open_dict
 from pytorch_lightning import Trainer
 
 from nemo.collections.nlp.parts.nlp_overrides import (
@@ -54,6 +56,20 @@
     --target_pipeline_model_parallel_size=1 \
     --target_pipeline_model_parallel_split_rank=0 \
     --precision=bf16
+
+# Megatron GPT + Virtual Pipeline parallelism
+
+python megatron_change_num_partitions.py \
+    --model_extracted_dir="" \
+    --target_file="" \
+    --ckpt_name="" \
+    --tensor_model_parallel_size= \
+    --target_tensor_model_parallel_size= \
+    --pipeline_model_parallel_size= \
+    --target_pipeline_model_parallel_size= \
+    --virtual_pipeline_model_parallel_size= \
+    --hparams_file="" \
+    --precision=bf16
 
 ### Only Tensor Parallelism conversion ###
 
@@ -100,6 +116,43 @@
 """
 
 
+def set_virtual_parallel_rank_safely(rank: int):
+    AppState().virtual_pipeline_model_parallel_rank = rank
+
+    try:
+        from megatron.core import parallel_state
+
+        parallel_state.set_virtual_pipeline_model_parallel_rank(rank)
+
+        if rank is None:
+            parallel_state.set_virtual_pipeline_model_parallel_world_size(0)
+
+    except (ImportError, ModuleNotFoundError):
+        logging.warning("`megatron-core` not installed, cannot set virtual parallel rank !")
+
+
+#################
+### Utilities ###
+#################
+
+
+def force_cpu_model(cfg):
+    with open_dict(cfg):
+        # temporarily
+        original_cpu_init = cfg.get('use_cpu_initialization', False)
+        original_amp_o2 = cfg.get('megatron_amp_O2', False)
+        cfg.use_cpu_initialization = True
+        cfg.megatron_amp_O2 = False
+    return cfg, {'original_cpu_init': original_cpu_init, 'original_amp_o2': original_amp_o2}
+
+
+def restore_model_config(cfg, original_dict):
+    with open_dict(cfg):
+        for key, val in original_dict.items():
+            cfg[key] = val
+    return cfg
+
+
 #################
 ### Utilities ###
 #################
@@ -732,6 +785,12 @@ def main():
     parser.add_argument(
         '--target_pipeline_model_parallel_split_rank', type=int, default=0, help='PP rank to split for Enc-Dec models'
     )
+    parser.add_argument(
+        '--virtual_pipeline_model_parallel_size', type=int, default=None, help='Virtual Pipeline parallelism size'
+    )
+    parser.add_argument(
+        '--ckpt_name', type=str, default=None, help='Checkpoint name to load from for Virtual Parallel'
+    )
     parser.add_argument(
         "--model_class",
         type=str,
@@ -759,6 +818,7 @@ def main():
         default=None,
         help="Path to the tokenizer model path if your model uses a tokenizer model as an artifact. This is needed if your model uses a sentencepiece tokenizer.",
     )
+    parser.add_argument('--hparams_file', type=str, default=None, help='Path to hparams file from PTL training')
     parser.add_argument('--tp_conversion_only', action='store_true', help='Only convert TP model to TP model')
     parser.add_argument('--model_extracted_dir', type=str, default=None, help='Path to pre-extracted model directory')
 
@@ -795,6 +855,25 @@ def main():
     pp_size = args.pipeline_model_parallel_size
     tgt_pp_size = args.target_pipeline_model_parallel_size
     pipeline_model_parallel_split_rank = args.target_pipeline_model_parallel_split_rank
+    vp_size = args.virtual_pipeline_model_parallel_size
+    if vp_size is None:
+        vp_size = 1
+
+    convert_vp = vp_size > 1
+    if convert_vp:
+        hparams_filepath = args.hparams_file
+        if hparams_filepath is None:
+            logging.warning(
+                '\n\n\n!!!!!!!!!\n'
+                'You are converting a model with virtual pipeline parallelism enabled, \n'
+                'but have not passed `hparams_file` argument. \n'
+                'This will cause each ckpt file to be temporarily loaded onto GPU memory!\n\n'
+                'It is highly recommended to pass `hparams_file` argument to avoid this.\n'
+            )
+    else:
+        hparams_filepath = None
+
+
     # Import the class of the model
     cls = model_utils.import_class_by_path(args.model_class)
     if args.model_file is None and args.model_extracted_dir is None:
@@ -830,10 +909,16 @@ def main():
         tgt_pp_size = 1
         pipeline_model_parallel_split_rank = 0
 
+    if vp_size is None or vp_size < 0:
+        vp_size = 1
+
     app_state = AppState()
     app_state.data_parallel_rank = 0
     app_state.pipeline_model_parallel_size = pp_size
     app_state.tensor_model_parallel_size = tp_size
+
+    if vp_size > 1:
+        app_state.virtual_pipeline_model_parallel_size = vp_size
     app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
 
     world_size = pp_size * tp_size  # pseudo world size for simulating load of a specific rank on a single gpu
@@ -841,87 +926,198 @@ def main():
     app_state.tensor_model_parallel_rank = 0
     app_state.pipeline_model_parallel_rank = 0
 
+    if vp_size > 1:
+        set_virtual_parallel_rank_safely(0)
+
     # If input model has TP > 1 or PP > 1
     # Reconstruct the model to have TP = 1 and PP = 1
     # Note that this is a forward loop that will process PP [0..N] TP [0..M] in sequential order.
     if tp_size > 1 or pp_size > 1:
-        partitions = {}
+        partitions = {}  # 3d list of VP x PP x TP
         model = None
-        for pp_rank in range(pp_size):
-            app_state.pipeline_model_parallel_rank = pp_rank
-            partitions[pp_rank] = []
-
-            for tp_rank in range(tp_size):
-                app_state.tensor_model_parallel_rank = tp_rank
-
-                logging.info(f"Loading ------------ PP Rank: {pp_rank} TP Rank: {tp_rank}")
-
-                # Override flag that forces Model to use AppState instead of Trainer
-                # to determine the world size, global and local rank
-                # Used for simulating load of a specific rank on a single gpu
-                os.environ[NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE] = "true"
-
-                # Compute the global rank to load the correct subset of parameters
-                global_rank = pp_rank * tp_size + tp_rank
-
-                # Update AppState
-                app_state.world_size = world_size
-                app_state.global_rank = global_rank
-                app_state.local_rank = global_rank % num_gpu_per_node
-                app_state.pipeline_model_parallel_size = pp_size
-                app_state.tensor_model_parallel_size = tp_size
-                app_state.pipeline_model_parallel_split_rank = pipeline_model_parallel_split_rank
-                app_state.model_parallel_size = (
-                    app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
-                )

-                save_restore_connector = NLPSaveRestoreConnector()
+        # Build partitions structure
+        for vp_idx in range(vp_size):
+            partitions[vp_idx] = []  # Build first layer - VP

-                if args.model_extracted_dir is not None:
-                    logging.info(f"Using extracted model directory: {args.model_extracted_dir}")
-                    save_restore_connector.model_extracted_dir = args.model_extracted_dir
+            for pp_idx in range(pp_size):
+                # For each VP, build PP x TP holder
+                partitions[vp_idx].append({})
+                partitions[vp_idx][pp_idx] = []

-                if args.model_file is not None:
-                    model_filepath = args.model_file
-                else:
-                    model_filepath = args.model_extracted_dir
+        for vp_rank in range(vp_size):
+            if vp_size > 1:
+                set_virtual_parallel_rank_safely(vp_rank)

-                model = cls.restore_from(
-                    restore_path=model_filepath,
-                    trainer=trainer,
-                    map_location=torch.device("cpu"),
-                    save_restore_connector=save_restore_connector,
-                )
-                model.to(dtype=dtype)
+            for pp_rank in range(pp_size):
+                app_state.pipeline_model_parallel_rank = pp_rank

-                # Reset env flag
-                os.environ.pop(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, None)
+                for tp_rank in range(tp_size):
+                    app_state.tensor_model_parallel_rank = tp_rank

-                logging.info(
-                    f"<<<<<<<< LOADED MODEL PP={pp_rank + 1} TP={tp_rank + 1} | "
-                    f"GLOBAL RANK = {global_rank} >>>>>>>>>"
-                )
-                params = [p for _, p in model.named_parameters()]
-                partitions[pp_rank].append(params)
+                    logging.info(f"Loading ------------ PP Rank: {pp_rank} TP Rank: {tp_rank}")

-                # app_state is being updated incorrectly during restore
-                app_state.data_parallel_rank = 0
-                app_state.pipeline_model_parallel_rank = pp_rank
-                app_state.tensor_model_parallel_rank = tp_rank
-                app_state.pipeline_model_parallel_size = pp_size
-                app_state.tensor_model_parallel_size = tp_size
-                app_state.model_parallel_size = (
-                    app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
-                )
+                    # Override flag that forces Model to use AppState instead of Trainer
+                    # to determine the world size, global and local rank
+                    # Used for simulating load of a specific rank on a single gpu
+                    os.environ[NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE] = "true"
+
+                    # Compute the global rank to load the correct subset of parameters
+                    global_rank = pp_rank * tp_size + tp_rank
+
+                    # Update AppState
+                    app_state.world_size = world_size
+                    app_state.global_rank = global_rank
+                    app_state.local_rank = global_rank % num_gpu_per_node
+                    app_state.pipeline_model_parallel_size = pp_size
+                    app_state.tensor_model_parallel_size = tp_size
+                    app_state.pipeline_model_parallel_split_rank = pipeline_model_parallel_split_rank
+                    app_state.model_parallel_size = (
+                        app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
+                    )
+
+                    if vp_size > 1:
+                        set_virtual_parallel_rank_safely(vp_rank)
+
+                    if vp_rank == 0:
+                        save_restore_connector = NLPSaveRestoreConnector()
+
+                        if args.model_extracted_dir is not None:
+                            logging.info(f"Using extracted model directory: {args.model_extracted_dir}")
+                            save_restore_connector.model_extracted_dir = args.model_extracted_dir
+
+                        if args.model_file is not None:
+                            model_filepath = args.model_file
+                        else:
+                            model_filepath = args.model_extracted_dir
+
+                        if vp_size == 1:
+
+                            # Get model config
+                            tmp_cfg = cls.restore_from(
+                                restore_path=model_filepath,
+                                trainer=trainer,
+                                map_location=torch.device("cpu"),
+                                save_restore_connector=save_restore_connector,
+                                return_config=True,
+                            )
+
+                            # Force model onto CPU
+                            tmp_cfg, restore_dict = force_cpu_model(tmp_cfg)
+
+                            # Restore model
+                            model = cls.restore_from(
+                                restore_path=model_filepath,
+                                trainer=trainer,
+                                map_location=torch.device("cpu"),
+                                save_restore_connector=save_restore_connector,
+                                override_config_path=tmp_cfg,
+                            )
+                            model.freeze()
+
+                            # Restore model config
+                            restore_model_config(model.cfg, restore_dict)
+
+                        else:
+                            if args.ckpt_name is None:
+                                raise ValueError(
+                                    "For Virtual Parallel, ckpt name is required.\n"
+                                    "Please provide `--ckpt_name` argument."
+                                )
+
+                            # inject model parallel rank
+                            checkpoint_path = model_utils.inject_model_parallel_rank(
+                                os.path.join(model_filepath, args.ckpt_name)
+                            )
+
+                            if hparams_filepath is not None:
+                                # Force the model onto CPU
+                                tmp_cfg = OmegaConf.load(hparams_filepath)
+                                tmp_cfg, restore_dict = force_cpu_model(tmp_cfg)
+
+                                with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', suffix='.yml') as tmp:
+                                    OmegaConf.save(tmp_cfg, tmp, resolve=True)
+                                    tmp.seek(0)
+
+                                    model = cls.load_from_checkpoint(
+                                        checkpoint_path=checkpoint_path,
+                                        trainer=trainer,
+                                        map_location=torch.device("cpu"),
+                                        hparams_file=tmp.name,
+                                    )
+                                    model.freeze()
+
+                                restore_model_config(model.cfg, restore_dict)
+
+                            else:
+                                model = cls.load_from_checkpoint(
+                                    checkpoint_path=checkpoint_path, trainer=trainer, map_location=torch.device("cpu"),
+                                )
+                                model.freeze()
+
+                        model.to(dtype=dtype)
+
+                    # Reset env flag
+                    os.environ.pop(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, None)
+
+                    logging.info(
+                        f"<<<<<<<< LOADED MODEL PP={pp_rank + 1} TP={tp_rank + 1} | "
+                        f"GLOBAL RANK = {global_rank} >>>>>>>>>"
+                    )
+
+                    # Save the parameters
+                    if vp_size == 1:
+                        params = [p for p in model.parameters()]
+                        partitions[vp_rank][pp_rank].append(params)  # vp_rank = 0
+
+                    else:
+                        vp_params_tmp = []
+                        for vp_idx in range(vp_size):
+                            set_virtual_parallel_rank_safely(vp_idx)
+                            params = [p for p in model.model[vp_idx].parameters()]
+                            # params = model.model[vp_idx].module.state_dict_for_save_checkpoint()
+                            # params = [p for p in params.values()]
+                            vp_params_tmp.append(params)
+                            # partitions[pp_rank][vp_idx].append(params)
+
+                        for vp_idx in range(vp_size):
+                            partitions[vp_idx][pp_rank].append(vp_params_tmp[vp_idx])
+
+                        del vp_params_tmp
+                        set_virtual_parallel_rank_safely(0)
+
+                    # app_state is being updated incorrectly during restore
+                    app_state.data_parallel_rank = 0
+                    app_state.pipeline_model_parallel_rank = pp_rank
+                    app_state.tensor_model_parallel_rank = tp_rank
+                    app_state.pipeline_model_parallel_size = pp_size
+                    app_state.tensor_model_parallel_size = tp_size
+                    app_state.model_parallel_size = (
+                        app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
+                    )
+
+                    if vp_size > 1:
+                        app_state.virtual_pipeline_model_parallel_size = vp_size
+                        set_virtual_parallel_rank_safely(vp_rank)

         # Build a unified model with PP 1 TP 1
         with open_dict(model.cfg):
             model.cfg.tensor_model_parallel_size = 1
             model.cfg.pipeline_model_parallel_size = 1
+            model.cfg.virtual_pipeline_model_parallel_size = None
+
+        app_state.global_rank = 0
+        app_state.local_rank = 0
+        app_state.data_parallel_rank = 0
+        app_state.pipeline_model_parallel_rank = 0
         app_state.tensor_model_parallel_rank = 0
-        app_state.pipeline_model_parallel_size = 0
+        app_state.pipeline_model_parallel_size = 1
+        app_state.tensor_model_parallel_size = 1
         app_state.model_parallel_size = 1

+        if vp_size > 1:
+            set_virtual_parallel_rank_safely(None)
+
         trainer = Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision)

         with open_dict(model.cfg):
@@ -930,25 +1126,52 @@ def main():
         if args.tokenizer_vocab_file is not None:
             model.cfg.tokenizer.vocab_file = args.tokenizer_vocab_file

-        # temporarily
-        original_cpu_init = model.cfg.get('use_cpu_initialization', False)
-        original_amp_o2 = model.cfg.get('megatron_amp_O2', False)
-        model.cfg.use_cpu_initialization = True
-        model.cfg.megatron_amp_O2 = False
+        model.cfg, restore_dict = force_cpu_model(model.cfg)

-        model = cls(model.cfg, trainer)
+        # Remove Virtual Parallelism
+        model.cfg.virtual_pipeline_model_parallel_size = None
+
+        logging.info(f"<<<<<<<< Building TP 1 PP 1 base model >>>>>>>>>")
+        model = cls(model.cfg, trainer)  # type: nn.Module
+        model.freeze()
         model = model.to('cpu')
         model._save_restore_connector = NLPSaveRestoreConnector()

+        vp_param_count = 0
+        for vp in range(vp_size):
+            for pp in range(pp_size):
+                for tp in range(tp_size):
+                    vp_param_count += len(partitions[vp][pp][tp])
+
+        if vp_size > 1:
+            logging.debug(f"Total params in TP PP VP = 1 : {len(list(model.parameters()))}")
+            logging.debug(f"Total params in VP PP TP (og): {vp_param_count}")
+
+        # Flatten Virtual Pipeline
+        if vp_size == 1:
+            # unpack vp container, pack pp tp container
+            partitions = partitions[0]
+            partitions = {idx: val for idx, val in enumerate(partitions)}
+        else:
+            flat_partitions = {idx: [] for idx in range(pp_size)}
+
+            for pp in range(pp_size):
+                for tp in range(tp_size):
+                    vp_cache = []
+                    for vp in range(vp_size):
+                        vp_cache.extend(partitions[vp][pp][tp])
+
+                    flat_partitions[pp].append(vp_cache)
+
+            partitions = flat_partitions
+
         if tgt_tp_size > 1 or tgt_pp_size > 1:
             merge_partition(model, partitions)
         else:
             # Write out the PP 1 TP 1 model to disk
             merge_partition(model, partitions, args.target_file)

-        with open_dict(model.cfg):
-            model.cfg.use_cpu_initialization = original_cpu_init
-            model.cfg.megatron_amp_O2 = original_amp_o2
+        restore_model_config(model.cfg, restore_dict)

         # Empty cache memory of all parameters from all PP TP partitions
         partitions.clear()
@@ -968,6 +1191,16 @@ def main():
         else:
             model_filepath = args.model_extracted_dir

+        tmp_cfg = cls.restore_from(
+            restore_path=model_filepath,
+            trainer=trainer,
+            map_location=torch.device("cpu"),
+            save_restore_connector=save_restore_connector,
+            return_config=True,
+        )
+
+        tmp_cfg, restore_dict = force_cpu_model(tmp_cfg)
+
         model = cls.restore_from(
             restore_path=model_filepath,
             trainer=trainer,
@@ -976,6 +1209,8 @@ def main():
         )
         model.to(dtype=dtype)

+        restore_model_config(model.cfg, restore_dict)
+
     # If target model has TP > 1 or PP > 1
     if tgt_pp_size > 1 or tgt_tp_size > 1:

@@ -1046,10 +1281,16 @@ def main():
         with open_dict(model.cfg):
             model.cfg.tokenizer.model = args.tokenizer_model_path

-        model = cls(model.cfg, trainer).to('cpu')
+        model.cfg, restore_dict = force_cpu_model(model.cfg)
+
+        model = cls(model.cfg, trainer)
+        model = model.to('cpu')
         model._save_restore_connector = NLPSaveRestoreConnector()
+        model.freeze()
         model.to(dtype=dtype)

+        restore_model_config(model.cfg, restore_dict)
+
         # Update global batch size
         if old_global_batch_size % new_global_batch_size != 0 or old_global_batch_size < new_global_batch_size:
             logging.info(
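
For reference, the key data-layout change in this patch is that `partitions` becomes a 3D structure indexed as [virtual pipeline rank][pipeline rank][tensor rank], which the script flattens back to a plain [pipeline rank][tensor rank] layout before calling merge_partition(). The following is a minimal, self-contained sketch of that flattening step; it is not code from the patch itself, and the sizes and string tags below are hypothetical stand-ins for real parameter tensors.

# Toy sketch of the VP -> PP flattening performed above (hypothetical sizes; strings stand in for params).
vp_size, pp_size, tp_size = 2, 2, 1

# Mimics the script's 3D layout: partitions[vp][pp][tp] -> list of parameters for that shard.
partitions = {
    vp: [[[f"vp{vp}.pp{pp}.tp{tp}.param"] for tp in range(tp_size)] for pp in range(pp_size)]
    for vp in range(vp_size)
}

# For every (pp, tp) slot, concatenate the VP chunks in virtual-rank order,
# yielding the plain PP x TP layout that the rest of the script consumes.
flat_partitions = {pp: [] for pp in range(pp_size)}
for pp in range(pp_size):
    for tp in range(tp_size):
        vp_cache = []
        for vp in range(vp_size):
            vp_cache.extend(partitions[vp][pp][tp])
        flat_partitions[pp].append(vp_cache)

print(flat_partitions)
# {0: [['vp0.pp0.tp0.param', 'vp1.pp0.tp0.param']], 1: [['vp0.pp1.tp0.param', 'vp1.pp1.tp0.param']]}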