From ae5d7e81b8e446e5650082b1700eb92dd2e7c1bd Mon Sep 17 00:00:00 2001 From: Igor Gitman Date: Sun, 3 Dec 2023 15:00:42 -0800 Subject: [PATCH 01/12] Pass in rotary_base to mcore and from HF (#7933) * Pass in rotary_base to mcore and from HF Signed-off-by: Igor Gitman * Allow changing rotary_base from the sft config file Signed-off-by: Igor Gitman * Update mcore in jenkins Signed-off-by: Igor Gitman --------- Signed-off-by: Igor Gitman Co-authored-by: Eric Harper --- Jenkinsfile | 4 ++-- examples/nlp/language_modeling/tuning/megatron_gpt_sft.py | 3 +++ .../nlp/models/language_modeling/megatron_gpt_model.py | 1 + scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py | 2 ++ 4 files changed, 8 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 1f974333dd3a..12fafac57a67 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -72,8 +72,8 @@ pipeline { steps { sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout e122536b7645edcb7ebf099b5c92a443f7dbf8e7 && \ - pip install -e .' + git checkout 973330e9c3681604703bf1eb6b5a265d1b9b9b38 && \ + pip install .' } } diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py index 79dd20fcf84a..b6325be40829 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py @@ -90,6 +90,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): if cfg.model.get('seq_len_interpolation_factor', None) is not None: gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor + if cfg.model.get('rotary_base', None) is not None: + gpt_cfg.rotary_base = cfg.model.rotary_base + sft_cls = MegatronGPTSFTModel gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}" diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 5b14532016c5..c2e39ea03a3e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -318,6 +318,7 @@ def model_provider_func(self, pre_process, post_process): position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'), rotary_percent=self.cfg.get('rotary_percentage', 1.0), seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None), + rotary_base=self.cfg.get('rotary_base', 10000), ) else: assert self.cfg.get('num_query_groups', None) is None or self.cfg.get( diff --git a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py index c281088f8c5c..d1453aeee972 100644 --- a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py @@ -116,6 +116,8 @@ def load_config(args, llama_config): nemo_config['seq_len_interpolation_factor'] = llama_config['rope_scaling']['factor'] else: raise ValueError("Only linear rope scaling type is supported now") + if llama_config['rope_theta'] is not None: + nemo_config['rotary_base'] = llama_config['rope_theta'] base = 128 while llama_config['vocab_size'] % base != 0: From 110c9d738d5a885e00824ea6559d1df8e20370ce Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Mon, 4 Dec 2023 11:23:07 -0800 Subject: [PATCH 02/12] Add interface to set NCCL options of each process group (#7923) Signed-off-by: Sangkug Lym Co-authored-by: Eric Harper --- 
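A note on PATCH 01 above: the new `rotary_base` knob is threaded through three places — the SFT tuning script (copied from the fine-tuning config when present), the mcore GPT model provider (defaulting to 10000), and the HF LLaMA converter (taken from `rope_theta`). As a minimal sketch of the SFT path only, the snippet below mirrors the `_modify_config` branch added in megatron_gpt_sft.py; the override value 1000000 is a hypothetical example, not something the patch prescribes.

    from omegaconf import OmegaConf

    # Restored pretraining config: falls back to the usual RoPE base of 10000.
    gpt_cfg = OmegaConf.create({'rotary_base': 10000})

    # Fine-tuning config file with a hypothetical model.rotary_base override.
    cfg = OmegaConf.create({'model': {'rotary_base': 1000000}})

    # Mirrors the new branch in _modify_config(): copy the value only when it is set.
    if cfg.model.get('rotary_base', None) is not None:
        gpt_cfg.rotary_base = cfg.model.rotary_base

    print(gpt_cfg.rotary_base)  # -> 1000000
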
examples/nlp/language_modeling/conf/megatron_gpt_config.yaml | 1 + nemo/collections/nlp/parts/megatron_trainer_builder.py | 1 + nemo/collections/nlp/parts/nlp_overrides.py | 5 ++++- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index bd34d54f5fd6..8a7dd689e970 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -131,6 +131,7 @@ model: apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) sync_batch_comm: False # Enable stream synchronization after each p2p communication between pipeline stages + nccl_communicator_config_path: null # Path to the yaml file with NCCL communicator options (min_ctas, max_ctas, and cga_cluster_size) ## Activation Checkpointing # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index b2554a35cdbd..69956129bdde 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -53,6 +53,7 @@ def _training_strategy(self) -> NLPDDPStrategy: no_ddp_communication_hook=True, gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view, find_unused_parameters=False, + nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None), ) def _grad_scaler(self) -> GradScaler: diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 0d13c31d9965..82cdae381701 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -81,6 +81,7 @@ class NLPDDPStrategy(DDPStrategy): Args: no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2 with FP32 gradient accumulation. + nccl_communicator_config_path: Path to the yaml file with NCCL communicator options """ def __init__( @@ -89,6 +90,7 @@ def __init__( cluster_environment: ClusterEnvironment = None, checkpoint_io: Optional[CheckpointIO] = None, no_ddp_communication_hook: bool = False, + nccl_communicator_config_path: Optional[str] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: if not HAVE_APEX: @@ -103,6 +105,7 @@ def __init__( super().__init__(parallel_devices, cluster_environment, checkpoint_io, **kwargs) self.no_ddp_communication_hook = no_ddp_communication_hook + self.nccl_communicator_config_path = nccl_communicator_config_path def setup(self, trainer: "pl.Trainer") -> None: """ @@ -180,7 +183,6 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None: Args: global_rank (int): the global process index. world_size (int): the total number of GPUs, num_nodes * num_devices - is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM. 
""" app_state = AppState() @@ -196,6 +198,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None: pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, + nccl_communicator_config_path=self.nccl_communicator_config_path, ) # assert that fake tp and pp rank match after model parallel init From 52d50e9e09a3e636d60535fd9882f3b3f32f92ad Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Mon, 4 Dec 2023 22:59:37 -0500 Subject: [PATCH 03/12] Support O2 training of PEFT and SFT (#7971) * support O2 Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py index 16a3850852d4..853ffc6ea012 100644 --- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py +++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py @@ -67,7 +67,7 @@ def __init__(self, *args, **kwargs): self.use_ptuning_only = False super().__init__(*args, **kwargs) if hasattr(self, "enc_dec_model"): - self.model_prefix = "enc_dec_model." # for T5 + self.model_prefix = "enc_dec_model.module." if self.cfg.megatron_amp_O2 else "enc_dec_model." # for T5 else: self.model_prefix = "model.module." if self.cfg.megatron_amp_O2 else "model." @@ -351,7 +351,7 @@ def sharded_state_dict(self, prefix: str = ''): if not use_mcore_gpt or (self.use_peft and self.setup_complete): return None else: - return self.model.sharded_state_dict(prefix=self.model_prefix) + return super().sharded_state_dict(prefix=prefix) def load_state_dict(self, state_dict, strict: bool = True): if len(state_dict) == 0: From f733f54f13be90062b9ffdf5f24da467ddb0cd7b Mon Sep 17 00:00:00 2001 From: Eric Harper Date: Wed, 6 Dec 2023 17:04:49 -0700 Subject: [PATCH 04/12] Add news section to README (#7984) * add news Signed-off-by: eharper * use rst Signed-off-by: eharper * fix image Signed-off-by: eharper * update image Signed-off-by: eharper * update Signed-off-by: eharper * update scale Signed-off-by: eharper * revert Signed-off-by: eharper * update width Signed-off-by: eharper --------- Signed-off-by: eharper --- README.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/README.rst b/README.rst index d07b07434b20..fba4aaf04f09 100644 --- a/README.rst +++ b/README.rst @@ -38,6 +38,22 @@ **NVIDIA NeMo** =============== +Latest News +----------- + +- 2023/12/06 `New NVIDIA NeMo Framework Features and NVIDIA H200 `_ + +.. image:: https://github.com/sbhavani/TransformerEngine/blob/main/docs/examples/H200-NeMo-performance.png + :target: https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility + :alt: H200-NeMo-performance + :width: 600 + +NeMo Framework has been updated with state-of-the-art features, +such as FSDP, Mixture-of-Experts, and RLHF with TensorRT-LLM to provide speedups up to 4.2x for Llama-2 pre-training on H200. 
+**All of these features will be available in an upcoming release.** + + + Introduction ------------ From bbadcf7115baa9177ff35c9c5b3186c65962dc3a Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Thu, 7 Dec 2023 09:21:09 +0100 Subject: [PATCH 05/12] [NLP] Access scaler only in FP16 case (#7916) * Remove unused 'precision' variable Signed-off-by: Jan Lasek * Raise informative error when trying to load FP16 model for trainer.precision != 16 Signed-off-by: Jan Lasek * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Make sure scaler is available instead of raising error Signed-off-by: Jan Lasek * trainer != None is assured thanks to a previous check Signed-off-by: Jan Lasek --------- Signed-off-by: Jan Lasek Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../nlp/models/language_modeling/megatron_base_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 6579e837b1a6..ccdd2e8725db 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -103,7 +103,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): self.tokenizer = None with open_dict(cfg): - if cfg.get('precision', None) is None and trainer is not None: + if cfg.get('precision', None) is None: cfg.precision = trainer.precision super().__init__(cfg, trainer=trainer, no_lm_init=no_lm_init) @@ -773,7 +773,6 @@ def build_model_parallel_config(self) -> ModelParallelConfig: cfg = OmegaConf.to_container(self.cfg, resolve=True) # map precision related configs - precision = cfg.get('precision', 32) # PTL trainer precision megatron_amp_O2 = cfg.get('megatron_amp_O2', False) # dtype used in p2p communication @@ -791,7 +790,7 @@ def build_model_parallel_config(self) -> ModelParallelConfig: and not self.cfg.get('sequence_parallel', False), "pipeline_dtype": pipeline_dtype, "grad_scale_func": self.trainer.precision_plugin.scaler.scale - if self.torch_dtype == torch.float16 + if self.trainer.precision in ["16", "16-mixed"] else None, "enable_autocast": not megatron_amp_O2 and self.torch_dtype in [torch.bfloat16, torch.float16], "autocast_dtype": self.autocast_dtype, From c822d5ca391bb468d46cad28e8e96db02ef8dccc Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 7 Dec 2023 10:27:37 -0800 Subject: [PATCH 06/12] fix librosa display issue (#7991) (#7993) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao --- tutorials/asr/ASR_with_NeMo.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/asr/ASR_with_NeMo.ipynb b/tutorials/asr/ASR_with_NeMo.ipynb index 74cd0f739e84..8a5a000d79bb 100644 --- a/tutorials/asr/ASR_with_NeMo.ipynb +++ b/tutorials/asr/ASR_with_NeMo.ipynb @@ -267,7 +267,7 @@ "plt.title('Waveform of Audio Example')\n", "plt.ylabel('Amplitude')\n", "\n", - "_ = librosa.display.waveshow(audio)" + "_ = librosa.display.waveshow(audio, color='blue')" ], "execution_count": null, "outputs": [] From 25f066f4e47ac5da9af48223b9a0d108b36dfd4d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 7 Dec 2023 10:29:09 -0800 Subject: [PATCH 07/12] Fix librosa issue (#7994) (#7995) Signed-off-by: 
smajumdar Co-authored-by: Somshubra Majumdar --- tutorials/asr/ASR_with_NeMo.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/asr/ASR_with_NeMo.ipynb b/tutorials/asr/ASR_with_NeMo.ipynb index 8a5a000d79bb..afda092b8ecc 100644 --- a/tutorials/asr/ASR_with_NeMo.ipynb +++ b/tutorials/asr/ASR_with_NeMo.ipynb @@ -330,7 +330,7 @@ }, "source": [ "# Plot the mel spectrogram of our sample\n", - "mel_spec = librosa.feature.melspectrogram(audio, sr=sample_rate)\n", + "mel_spec = librosa.feature.melspectrogram(y=audio, sr=sample_rate)\n", "mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)\n", "\n", "librosa.display.specshow(\n", From 88d3a4d71f6b97a5eb10a4e43fb07df8aee32861 Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Thu, 7 Dec 2023 21:57:32 +0100 Subject: [PATCH 08/12] Minor fixes (#7978) Signed-off-by: Jan Lasek --- .../nlp_language_modeling/convert_hf_llama_to_nemo.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py index d1453aeee972..f6fd0dedd94d 100644 --- a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py @@ -17,8 +17,7 @@ Example to run this conversion script: python convert_hf_llama_to_nemo.py \ --in-file \ - --out-file \ - [--fast-swiglu\ + --out-file """ import os @@ -50,7 +49,7 @@ def get_args(): "--in-file", type=str, default=None, required=True, help="Path to Huggingface LLaMA checkpoints", ) parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.") - parser.add_argument("--precision", type=str, default="32", help="Model precision") + parser.add_argument("--precision", type=str, default="16", help="Model precision") args = parser.parse_args() return args @@ -94,7 +93,7 @@ def load_model(cls, checkpoint, strict, **kwargs): return model -def load_config(args, llama_config): +def load_config(llama_config): nemo_config = OmegaConf.load( os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml') ).model @@ -138,7 +137,7 @@ def convert(args): for name, param in model.named_parameters(): print(f"- {name}") - nemo_config = load_config(args, hf_config) + nemo_config = load_config(hf_config) if args.precision in ["32", "16"]: precision = int(float(args.precision)) From 5103a9adb9c6f2b72b226b6c078c6b3f202a0cd5 Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Thu, 7 Dec 2023 21:58:28 +0100 Subject: [PATCH 09/12] Resolve dtype with utils_funcs.py (#7979) Signed-off-by: Jan Lasek --- nemo/collections/nlp/parts/utils_funcs.py | 8 +++++++- .../nlp_language_modeling/convert_hf_llama_to_nemo.py | 11 ++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/nemo/collections/nlp/parts/utils_funcs.py b/nemo/collections/nlp/parts/utils_funcs.py index 5185c6cf9b5a..2ec77faf91f5 100644 --- a/nemo/collections/nlp/parts/utils_funcs.py +++ b/nemo/collections/nlp/parts/utils_funcs.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__all__ = ['list2str', 'tensor2list', 'plot_confusion_matrix', 'get_classification_report'] +__all__ = [ + 'torch_dtype_from_precision', + 'list2str', + 'tensor2list', + 'plot_confusion_matrix', + 'get_classification_report', +] import os import time diff --git a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py index f6fd0dedd94d..d6007aa771c0 100644 --- a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py @@ -40,6 +40,7 @@ NLPSaveRestoreConnector, PipelineMixedPrecisionPlugin, ) +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import logging @@ -169,15 +170,6 @@ def convert(args): else: plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) - if precision == 32: - dtype = torch.float32 - elif precision in [16, "16", "16-mixed"]: - dtype = torch.float16 - elif precision in ["bf16", "bf16-mixed"]: - dtype = torch.bfloat16 - else: - dtype = torch.float32 # fallback - nemo_config.precision = precision print(f"nemo_config: {nemo_config}") @@ -314,6 +306,7 @@ def convert(args): model._save_restore_connector = NLPSaveRestoreConnector() # cast to target precision and disable cpu init + dtype = torch_dtype_from_precision(precision) model = model.to(dtype=dtype) model.cfg.use_cpu_initialization = False From 663bd0a23baaced8898fffee090cf84682e5dba8 Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Thu, 7 Dec 2023 22:00:49 +0100 Subject: [PATCH 10/12] Remove replace_sampler_ddp (deprecated in Trainer) (#7981) Signed-off-by: Jan Lasek --- scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py b/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py index fd761b6b20c2..2261f70ea928 100644 --- a/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py @@ -148,7 +148,6 @@ 'precision': 'bf16', 'logger': False, # logger provided by exp_manager 'enable_checkpointing': False, - 'replace_sampler_ddp': False, 'max_epochs': -1, # PTL default. In practice, max_steps will be reached first. 
'max_steps': 100000, # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches 'log_every_n_steps': 10, From c6cf2761fb3cffebcc9e8b071ff7895a9c8141a9 Mon Sep 17 00:00:00 2001 From: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com> Date: Fri, 8 Dec 2023 11:37:12 -0800 Subject: [PATCH 11/12] Fixing conversion script to work for code llama (#7997) * Fixing conversion script to work for code llama * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Shanmugam Ramasamy Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py index d6007aa771c0..58abcd20442d 100644 --- a/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py @@ -98,6 +98,9 @@ def load_config(llama_config): nemo_config = OmegaConf.load( os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml') ).model + + if llama_config.get('rope_theta', None): + nemo_config['rotary_base'] = llama_config['rope_theta'] nemo_config.encoder_seq_length = llama_config['max_position_embeddings'] nemo_config.num_layers = int(llama_config['num_hidden_layers']) nemo_config.hidden_size = llama_config['hidden_size'] From fa8d416793d850f4ce56bea65e1fe28cc0d092c0 Mon Sep 17 00:00:00 2001 From: trias702 <25867060+trias702@users.noreply.github.com> Date: Sat, 9 Dec 2023 17:41:52 -0800 Subject: [PATCH 12/12] Reworked MegatronPretrainingRandomBatchSampler to correctly handle epochs > 1 (#7920) * Initital commit of reworked MegatronPretrainingRandomBatchSampler Signed-off-by: Daniel Egert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed small length based bug Signed-off-by: Daniel Egert * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Daniel Egert Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper --- .../megatron/megatron_batch_samplers.py | 56 +++++++++++++++---- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py index 8b06ac951a66..87cb6cb8c8bc 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py @@ -175,6 +175,11 @@ class MegatronPretrainingRandomBatchSampler(BaseMegatronBatchSampler): # are necessary for ViT training. However, to keep this simple, # I omit those two arguments. 
# commit: https://github.com/NVIDIA/Megatron-LM/commit/7a77abd9b6267dc0020a60b424b4748fc22790bb + # + # NOTE (degert): I have re-written this class somewhat as previous implementation relied on the + # base class constructor which would have thrown in the case of consumed_samples >= total_samples + # which this class was designed to do, as that is how it implicitly calculates the current epoch + # I have also added an explicit seed which allows us to remove Dataset-side shuffling in Nemo-Aligner def __init__( self, total_samples: int, @@ -184,20 +189,47 @@ def __init__( data_parallel_rank: int, data_parallel_size: int, drop_last: bool, + pad_samples_to_global_batch_size: bool = False, + seed: int = 0, ) -> None: - super().__init__( - total_samples=total_samples, - consumed_samples=consumed_samples, - micro_batch_size=micro_batch_size, - global_batch_size=global_batch_size, - data_parallel_rank=data_parallel_rank, - data_parallel_size=data_parallel_size, - drop_last=drop_last, - ) + + # Sanity checks. + if total_samples <= 0: + raise RuntimeError("no sample to consume: {}".format(total_samples)) + if micro_batch_size <= 0: + raise RuntimeError(f"micro_batch_size size must be greater than 0, but {micro_batch_size}") + if data_parallel_size <= 0: + raise RuntimeError(f"data parallel size must be greater than 0, but {data_parallel_size}") + if data_parallel_rank >= data_parallel_size: + raise RuntimeError( + "data_parallel_rank should be smaller than data size, but {} >= {}".format( + data_parallel_rank, data_parallel_size + ) + ) + + self.total_samples: int = total_samples + self.consumed_samples: int = consumed_samples + self.micro_batch_size: int = micro_batch_size + self.data_parallel_rank: int = data_parallel_rank + self.data_parallel_size: int = data_parallel_size + self.drop_last: bool = drop_last + self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size + self.seed = seed + + self.update_global_batch_size(global_batch_size) self.last_batch_size = self.total_samples % self._global_batch_size - def __len__(self): - num_available_samples = self.total_samples + def __len__(self) -> int: + """Length of Random Batch Sampler. + + ..note:: + When `rampup_batch_size` is enabled, the return value can be not exactly precise. + + """ + active_total_samples = self.total_samples - self.last_batch_size + num_available_samples = ( + active_total_samples * (1 + (self.consumed_samples // active_total_samples)) + ) - self.consumed_samples if self.drop_last: return num_available_samples // self.global_batch_size else: @@ -215,7 +247,7 @@ def __iter__(self): start_idx = self.data_parallel_rank * bucket_size g = torch.Generator() - g.manual_seed(self.epoch) + g.manual_seed(self.seed + self.epoch) random_idx = torch.randperm(bucket_size, generator=g).tolist() idx_range = [start_idx + x for x in random_idx[bucket_offset:]]