Commit

Merge branch 'main' into tkonuk/lora
arendu authored Dec 11, 2023
2 parents bda31e9 + fa8d416 commit 635b014
Showing 14 changed files with 96 additions and 38 deletions.
4 changes: 2 additions & 2 deletions Jenkinsfile
@@ -72,8 +72,8 @@ pipeline {
steps {
sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
git checkout e122536b7645edcb7ebf099b5c92a443f7dbf8e7 && \
pip install -e .'
git checkout 973330e9c3681604703bf1eb6b5a265d1b9b9b38 && \
pip install .'
}
}

16 changes: 16 additions & 0 deletions README.rst
@@ -38,6 +38,22 @@
**NVIDIA NeMo**
===============

Latest News
-----------

- 2023/12/06 `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility/>`_

.. image:: https://github.com/sbhavani/TransformerEngine/blob/main/docs/examples/H200-NeMo-performance.png
:target: https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility
:alt: H200-NeMo-performance
:width: 600

NeMo Framework has been updated with state-of-the-art features,
such as FSDP, Mixture-of-Experts, and RLHF with TensorRT-LLM to provide speedups up to 4.2x for Llama-2 pre-training on H200.
**All of these features will be available in an upcoming release.**



Introduction
------------


@@ -131,6 +131,7 @@ model:
apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
sync_batch_comm: False # Enable stream synchronization after each p2p communication between pipeline stages
nccl_communicator_config_path: null # Path to the yaml file with NCCL communicator options (min_ctas, max_ctas, and cga_cluster_size)

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
Expand Down
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -90,6 +90,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
if cfg.model.get('seq_len_interpolation_factor', None) is not None:
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

if cfg.model.get('rotary_base', None) is not None:
gpt_cfg.rotary_base = cfg.model.rotary_base

sft_cls = MegatronGPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"


@@ -175,6 +175,11 @@ class MegatronPretrainingRandomBatchSampler(BaseMegatronBatchSampler):
# are necessary for ViT training. However, to keep this simple,
# I omit those two arguments.
# commit: https://github.com/NVIDIA/Megatron-LM/commit/7a77abd9b6267dc0020a60b424b4748fc22790bb
#
# NOTE (degert): I have re-written this class somewhat, as the previous implementation relied on the
# base-class constructor, which would have raised whenever consumed_samples >= total_samples, a case
# this class is designed to handle, since that is how it implicitly calculates the current epoch.
# I have also added an explicit seed, which allows us to remove Dataset-side shuffling in NeMo-Aligner.
def __init__(
self,
total_samples: int,
@@ -184,20 +189,47 @@ def __init__(
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool,
pad_samples_to_global_batch_size: bool = False,
seed: int = 0,
) -> None:
super().__init__(
total_samples=total_samples,
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
data_parallel_rank=data_parallel_rank,
data_parallel_size=data_parallel_size,
drop_last=drop_last,
)

# Sanity checks.
if total_samples <= 0:
raise RuntimeError("no sample to consume: {}".format(total_samples))
if micro_batch_size <= 0:
raise RuntimeError(f"micro_batch_size size must be greater than 0, but {micro_batch_size}")
if data_parallel_size <= 0:
raise RuntimeError(f"data parallel size must be greater than 0, but {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise RuntimeError(
"data_parallel_rank should be smaller than data size, but {} >= {}".format(
data_parallel_rank, data_parallel_size
)
)

self.total_samples: int = total_samples
self.consumed_samples: int = consumed_samples
self.micro_batch_size: int = micro_batch_size
self.data_parallel_rank: int = data_parallel_rank
self.data_parallel_size: int = data_parallel_size
self.drop_last: bool = drop_last
self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
self.seed = seed

self.update_global_batch_size(global_batch_size)
self.last_batch_size = self.total_samples % self._global_batch_size

def __len__(self):
num_available_samples = self.total_samples
def __len__(self) -> int:
"""Length of Random Batch Sampler.
..note::
When `rampup_batch_size` is enabled, the return value can be not exactly precise.
"""
active_total_samples = self.total_samples - self.last_batch_size
num_available_samples = (
active_total_samples * (1 + (self.consumed_samples // active_total_samples))
) - self.consumed_samples
if self.drop_last:
return num_available_samples // self.global_batch_size
else:
@@ -215,7 +247,7 @@ def __iter__(self):
start_idx = self.data_parallel_rank * bucket_size

g = torch.Generator()
g.manual_seed(self.epoch)
g.manual_seed(self.seed + self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

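To make the re-written sampler's bookkeeping concrete, here is a small illustrative sketch (not part of the diff) of how the current epoch and the samples remaining in it fall out of `consumed_samples` alone, which is exactly why the base-class constructor check on `consumed_samples >= total_samples` had to be bypassed:

    def epoch_and_remaining(total_samples: int, last_batch_size: int, consumed_samples: int):
        # Samples that fit into complete global batches; the trailing partial
        # batch (last_batch_size) is excluded, mirroring __len__ above.
        active_total_samples = total_samples - last_batch_size
        epoch = consumed_samples // active_total_samples          # wraps past the dataset end
        remaining = active_total_samples * (1 + epoch) - consumed_samples
        return epoch, remaining

    # e.g. total_samples=1000, last_batch_size=8, consumed_samples=2500
    # -> epoch 2 with 476 samples left; __iter__ then seeds shuffling with
    #    seed + epoch, so every data-parallel rank reshuffles identically per epoch.
    print(epoch_and_remaining(1000, 8, 2500))
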

@@ -103,7 +103,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
self.tokenizer = None

with open_dict(cfg):
if cfg.get('precision', None) is None and trainer is not None:
if cfg.get('precision', None) is None:
cfg.precision = trainer.precision

super().__init__(cfg, trainer=trainer, no_lm_init=no_lm_init)
@@ -773,7 +773,6 @@ def build_model_parallel_config(self) -> ModelParallelConfig:
cfg = OmegaConf.to_container(self.cfg, resolve=True)

# map precision related configs
precision = cfg.get('precision', 32) # PTL trainer precision
megatron_amp_O2 = cfg.get('megatron_amp_O2', False)

# dtype used in p2p communication
@@ -791,7 +790,7 @@ def build_model_parallel_config(self) -> ModelParallelConfig:
and not self.cfg.get('sequence_parallel', False),
"pipeline_dtype": pipeline_dtype,
"grad_scale_func": self.trainer.precision_plugin.scaler.scale
if self.torch_dtype == torch.float16
if self.trainer.precision in ["16", "16-mixed"]
else None,
"enable_autocast": not megatron_amp_O2 and self.torch_dtype in [torch.bfloat16, torch.float16],
"autocast_dtype": self.autocast_dtype,

@@ -318,6 +318,7 @@ def model_provider_func(self, pre_process, post_process):
position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'),
rotary_percent=self.cfg.get('rotary_percentage', 1.0),
seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
rotary_base=self.cfg.get('rotary_base', 10000),
)
else:
assert self.cfg.get('num_query_groups', None) is None or self.cfg.get(
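The `rotary_base` value threaded through here (and through `megatron_gpt_sft.py` and the Llama converter as `rope_theta`) is the base of the rotary position embedding frequency schedule. A standalone sketch of the standard RoPE inverse-frequency formula, independent of the Megatron core implementation:

    import torch

    def rope_inv_freq(dim: int, rotary_base: float = 10000.0) -> torch.Tensor:
        # inv_freq_i = 1 / rotary_base^(2i/dim): a larger base rotates more slowly,
        # which is how Llama variants with rope_theta > 10000 stretch their context.
        return 1.0 / (rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

    print(rope_inv_freq(8, rotary_base=10000.0))
    print(rope_inv_freq(8, rotary_base=1_000_000.0))
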
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -53,6 +53,7 @@ def _training_strategy(self) -> NLPDDPStrategy:
no_ddp_communication_hook=True,
gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
nccl_communicator_config_path=self.cfg.model.get('nccl_communicator_config_path', None),
)

def _grad_scaler(self) -> GradScaler:
4 changes: 2 additions & 2 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -67,7 +67,7 @@ def __init__(self, *args, **kwargs):
self.use_ptuning_only = False
super().__init__(*args, **kwargs)
if hasattr(self, "enc_dec_model"):
self.model_prefix = "enc_dec_model." # for T5
self.model_prefix = "enc_dec_model.module." if self.cfg.megatron_amp_O2 else "enc_dec_model." # for T5
else:
self.model_prefix = "model.module." if self.cfg.megatron_amp_O2 else "model."

@@ -351,7 +351,7 @@ def sharded_state_dict(self, prefix: str = ''):
if not use_mcore_gpt or (self.use_peft and self.setup_complete):
return None
else:
return self.model.sharded_state_dict(prefix=self.model_prefix)
return super().sharded_state_dict(prefix=prefix)

def load_state_dict(self, state_dict, strict: bool = True):
if len(state_dict) == 0:
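For context on the T5 prefix change above: with `megatron_amp_O2` the language model is wrapped in a half-precision module wrapper, which inserts an extra `.module` level into parameter names, so the adapter mixin has to pick the matching prefix for encoder-decoder models too. An illustrative sketch (not the library code) of the selection logic:

    def peft_key_prefix(is_t5: bool, megatron_amp_O2: bool) -> str:
        # Mirrors the selection in the diff above: the O2 wrapper adds
        # one ".module" level to every parameter name.
        base = "enc_dec_model." if is_t5 else "model."
        return base.rstrip(".") + ".module." if megatron_amp_O2 else base

    # e.g. "model.module.decoder.layers.0.mlp.weight" under O2 versus
    #      "model.decoder.layers.0.mlp.weight" otherwise.
    print(peft_key_prefix(is_t5=True, megatron_amp_O2=True))    # enc_dec_model.module.
    print(peft_key_prefix(is_t5=False, megatron_amp_O2=False))  # model.
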
5 changes: 4 additions & 1 deletion nemo/collections/nlp/parts/nlp_overrides.py
@@ -81,6 +81,7 @@ class NLPDDPStrategy(DDPStrategy):
Args:
no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2
with FP32 gradient accumulation.
nccl_communicator_config_path: Path to the yaml file with NCCL communicator options
"""

def __init__(
@@ -89,6 +90,7 @@ def __init__(
cluster_environment: ClusterEnvironment = None,
checkpoint_io: Optional[CheckpointIO] = None,
no_ddp_communication_hook: bool = False,
nccl_communicator_config_path: Optional[str] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
if not HAVE_APEX:
@@ -103,6 +105,7 @@ def __init__(
super().__init__(parallel_devices, cluster_environment, checkpoint_io, **kwargs)

self.no_ddp_communication_hook = no_ddp_communication_hook
self.nccl_communicator_config_path = nccl_communicator_config_path

def setup(self, trainer: "pl.Trainer") -> None:
"""
@@ -180,7 +183,6 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
Args:
global_rank (int): the global process index.
world_size (int): the total number of GPUs, num_nodes * num_devices
is_slurm_managing_tasks (bool, optional): is the cluster managed by SLURM.
"""
app_state = AppState()

@@ -196,6 +198,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
nccl_communicator_config_path=self.nccl_communicator_config_path,
)

# assert that fake tp and pp rank match after model parallel init
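Putting the NCCL plumbing together: the trainer builder now reads `model.nccl_communicator_config_path` from the model config and forwards it through `NLPDDPStrategy` into Megatron's model-parallel initialization. A usage sketch with placeholder values, based only on the constructor arguments shown in this diff:

    from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy

    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=True,   # as in MegatronTrainerBuilder
        gradient_as_bucket_view=True,     # DDP kwarg forwarded via **kwargs
        find_unused_parameters=False,
        nccl_communicator_config_path="/path/to/nccl_communicator_config.yaml",
    )
    # When the strategy initializes model parallelism it hands the path to
    # Megatron core, which applies the per-communicator NCCL options.
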
8 changes: 7 additions & 1 deletion nemo/collections/nlp/parts/utils_funcs.py
@@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['list2str', 'tensor2list', 'plot_confusion_matrix', 'get_classification_report']
__all__ = [
'torch_dtype_from_precision',
'list2str',
'tensor2list',
'plot_confusion_matrix',
'get_classification_report',
]

import os
import time
25 changes: 11 additions & 14 deletions scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
@@ -17,8 +17,7 @@
Example to run this conversion script:
python convert_hf_llama_to_nemo.py \
--in-file <path_to_hf_checkpoints_folder> \
--out-file <path_to_output_nemo_file> \
[--fast-swiglu\
--out-file <path_to_output_nemo_file>
"""

import os
Expand All @@ -41,6 +40,7 @@
NLPSaveRestoreConnector,
PipelineMixedPrecisionPlugin,
)
from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.utils import logging


@@ -50,7 +50,7 @@ def get_args():
"--in-file", type=str, default=None, required=True, help="Path to Huggingface LLaMA checkpoints",
)
parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.")
parser.add_argument("--precision", type=str, default="32", help="Model precision")
parser.add_argument("--precision", type=str, default="16", help="Model precision")
args = parser.parse_args()
return args

@@ -94,10 +94,13 @@ def load_model(cls, checkpoint, strict, **kwargs):
return model


def load_config(args, llama_config):
def load_config(llama_config):
nemo_config = OmegaConf.load(
os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml')
).model

if llama_config.get('rope_theta', None):
nemo_config['rotary_base'] = llama_config['rope_theta']
nemo_config.encoder_seq_length = llama_config['max_position_embeddings']
nemo_config.num_layers = int(llama_config['num_hidden_layers'])
nemo_config.hidden_size = llama_config['hidden_size']
@@ -116,6 +119,8 @@ def load_config(args, llama_config):
nemo_config['seq_len_interpolation_factor'] = llama_config['rope_scaling']['factor']
else:
raise ValueError("Only linear rope scaling type is supported now")
if llama_config['rope_theta'] is not None:
nemo_config['rotary_base'] = llama_config['rope_theta']

base = 128
while llama_config['vocab_size'] % base != 0:
@@ -136,7 +141,7 @@ def convert(args):
for name, param in model.named_parameters():
print(f"- {name}")

nemo_config = load_config(args, hf_config)
nemo_config = load_config(hf_config)

if args.precision in ["32", "16"]:
precision = int(float(args.precision))
@@ -168,15 +173,6 @@ def convert(args):
else:
plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))

if precision == 32:
dtype = torch.float32
elif precision in [16, "16", "16-mixed"]:
dtype = torch.float16
elif precision in ["bf16", "bf16-mixed"]:
dtype = torch.bfloat16
else:
dtype = torch.float32 # fallback

nemo_config.precision = precision
print(f"nemo_config: {nemo_config}")

@@ -313,6 +309,7 @@ def convert(args):
model._save_restore_connector = NLPSaveRestoreConnector()

# cast to target precision and disable cpu init
dtype = torch_dtype_from_precision(precision)
model = model.to(dtype=dtype)
model.cfg.use_cpu_initialization = False

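The converter now delegates the precision-to-dtype mapping to `torch_dtype_from_precision`, newly exported from `nemo/collections/nlp/parts/utils_funcs.py`. Judging from the inline block it replaces, the helper behaves roughly like this local re-sketch (the name and the exact set of accepted strings below are assumptions, not the library code):

    import torch

    def dtype_from_precision_sketch(precision) -> torch.dtype:
        # Approximation of the mapping removed from convert(); the real helper
        # may accept additional precision strings.
        if precision in (16, "16", "16-mixed"):
            return torch.float16
        if precision in ("bf16", "bf16-mixed"):
            return torch.bfloat16
        if precision in (32, "32", "32-true"):
            return torch.float32
        raise ValueError(f"Unsupported precision: {precision}")

    assert dtype_from_precision_sketch("bf16-mixed") is torch.bfloat16
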
1 change: 0 additions & 1 deletion scripts/nlp_language_modeling/convert_mpt_7b_hf_to_nemo.py
@@ -148,7 +148,6 @@
'precision': 'bf16',
'logger': False, # logger provided by exp_manager
'enable_checkpointing': False,
'replace_sampler_ddp': False,
'max_epochs': -1, # PTL default. In practice, max_steps will be reached first.
'max_steps': 100000, # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
'log_every_n_steps': 10,
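The dropped `replace_sampler_ddp` key tracks PyTorch Lightning 2.x, where that Trainer flag was renamed. If the same behaviour is still needed, the modern equivalent (an assumption about the caller's intent, shown with the other values from this trainer config) would look like:

    import pytorch_lightning as pl

    # PTL >= 2.0: use_distributed_sampler replaces the removed replace_sampler_ddp.
    trainer = pl.Trainer(
        precision="bf16",
        logger=False,
        enable_checkpointing=False,
        use_distributed_sampler=False,  # keep NeMo's Megatron samplers untouched
        max_epochs=-1,
        max_steps=100000,
        log_every_n_steps=10,
    )
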
4 changes: 2 additions & 2 deletions tutorials/asr/ASR_with_NeMo.ipynb
@@ -267,7 +267,7 @@
"plt.title('Waveform of Audio Example')\n",
"plt.ylabel('Amplitude')\n",
"\n",
"_ = librosa.display.waveshow(audio)"
"_ = librosa.display.waveshow(audio, color='blue')"
],
"execution_count": null,
"outputs": []
@@ -330,7 +330,7 @@
},
"source": [
"# Plot the mel spectrogram of our sample\n",
"mel_spec = librosa.feature.melspectrogram(audio, sr=sample_rate)\n",
"mel_spec = librosa.feature.melspectrogram(y=audio, sr=sample_rate)\n",
"mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)\n",
"\n",
"librosa.display.specshow(\n",
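The two notebook tweaks track newer librosa releases, where `melspectrogram` and related functions take the audio signal as the keyword `y=` rather than positionally. A minimal standalone check, assuming librosa >= 0.10 and its bundled example audio:

    import librosa
    import numpy as np

    y, sr = librosa.load(librosa.ex("trumpet"))       # bundled example clip
    mel = librosa.feature.melspectrogram(y=y, sr=sr)  # keyword arguments required in new librosa
    mel_db = librosa.power_to_db(mel, ref=np.max)
    print(mel_db.shape)
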
