
Merge branch 'main' into markelsanz14/disable_act_ckpt
Signed-off-by: Markel Sanz Ausin <[email protected]>
markelsanz14 committed Apr 13, 2023
2 parents 936c2e9 + 7854bd4 commit fa0f8d7
Showing 75 changed files with 2,063 additions and 1,879 deletions.
7 changes: 6 additions & 1 deletion Dockerfile
@@ -42,8 +42,13 @@ RUN apt-get update && \
libavdevice-dev && \
rm -rf /var/lib/apt/lists/*

WORKDIR /tmp/
WORKDIR /workspace/
# Install Megatron-core
RUN git clone https://github.com/aklife97/Megatron-LM.git && \
cd Megatron-LM && \
pip install -e .

WORKDIR /tmp/
# TODO: Remove once this Apex commit (2/24/23) is included in PyTorch
# container
RUN git clone https://github.com/NVIDIA/apex.git && \
8 changes: 8 additions & 0 deletions Jenkinsfile
@@ -57,6 +57,14 @@ pipeline {
}
}

// TODO: remove when pip package is available
stage('Megatron Core installation') {
steps {
sh 'git clone https://github.com/aklife97/Megatron-LM.git && \
cd Megatron-LM && \
pip install -e .'
}
}

stage('PyTorch Lightning version') {
steps {
11 changes: 11 additions & 0 deletions docs/source/core/exp_manager.rst
@@ -173,9 +173,20 @@ and stability. To use EMA, simply set the following via YAML or :class:`~nemo.ut
every_n_steps: 1 # How often to update EMA weights
validate_original_weights: False # Whether to use original weights for validation calculation or EMA weights
Support for Preemption
----------------------

.. _exp_manager_preemption_support-label:

NeMo adds support for a callback that handles preemption while running models on clusters. The callback saves the current training state to a ``.ckpt``
file and then exits the run gracefully. The checkpoint saved upon preemption carries the ``*last.ckpt`` suffix and replaces the previously saved last checkpoint.
This feature helps increase cluster utilization.
The ``PreemptionCallback`` is enabled by default. To disable it, add ``create_preemption_callback: False`` under ``exp_manager`` in the config YAML file.
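
For example, a minimal ``exp_manager`` section that turns the callback off might look like the following (an illustrative sketch, not a complete config)::

    exp_manager:
      create_preemption_callback: False  # skip the graceful-exit checkpoint on preemption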


.. _nemo_multirun-label:


Hydra Multi-Run with NeMo
-------------------------

@@ -42,7 +42,7 @@ model:
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "fuly_randomized"
bucketing_strategy: "fully_randomized"
bucketing_batch_size: null

validation_ds:
@@ -36,7 +36,7 @@ model:
tarred_audio_filepaths: null
shuffle_n: 2048
# bucketing params
bucketing_strategy: "fuly_randomized"
bucketing_strategy: "fully_randomized"
bucketing_batch_size: null

validation_ds:
59 changes: 45 additions & 14 deletions examples/nlp/language_modeling/megatron_change_num_partitions.py
@@ -280,13 +280,19 @@ def split_partition(
# Special case for GPT models - whose last PP TP rank has a duplicate embedding tensor

# Megatron GPT check for final PP rank duplicated embeddings
if (
pp_rank == (pp_size - 1) and hasattr(model, 'model') and hasattr(model.model, 'word_embeddings')
): # duplicate embedding copy (tied weights)
duplicate_word_embedding_offset = 1
else:
duplicate_word_embedding_offset = 0
idx += duplicate_word_embedding_offset # add duplicate embedding offset to index
duplicate_gpt_word_embedding_offset = 0
untied_gpt_embedding = False

if 'gpt' in model.cfg.target.lower():
logging.info("Splitting GPT model")
if pp_rank == (pp_size - 1) and hasattr(model, 'model') and hasattr(model.model, 'word_embeddings'):
# duplicate embedding copy (tied weights)
duplicate_gpt_word_embedding_offset = 1

if model.cfg.get('share_embeddings_and_output_weights', True) is False:
untied_gpt_embedding = True

idx += duplicate_gpt_word_embedding_offset # add duplicate embedding offset to index

# Special case for T5 models - where the embeddings are shared between encoder and decoder
# and the rank of decoder split is arbitrary.
@@ -334,8 +340,8 @@ def split_partition(

# Print some debug info
logging.info(f"Start Layer Idx: {idx} Number of layers in current rank: {num_params} Offset: {offset}")
if duplicate_word_embedding_offset > 0:
logging.info(f"GPT duplicate_word_embedding_offset: {duplicate_word_embedding_offset}")
if duplicate_gpt_word_embedding_offset > 0:
logging.info(f"GPT duplicate_gpt_word_embedding_offset: {duplicate_gpt_word_embedding_offset}")
if enc_dec_share_token_embeddings_count:
logging.info(f"EncDec share_token_embeddings_count: {enc_dec_share_token_embeddings_count}")
if shared_enc_dec_embeddings_intermediate:
@@ -356,7 +362,7 @@ def split_partition(
# but GPT has an additional word embedding as its last parameter
# Therefore we check for this, and reset the index to the parameter of the PP 0 TP 0 rank
# which holds the parameters of the embedding.
if idx == (len(partitions[0])) and duplicate_word_embedding_offset > 0:
if idx == (len(partitions[0])) and duplicate_gpt_word_embedding_offset > 0:
logging.info("Found duplicate embedding copy for GPT model, resetting index")
idx = 0 # reset idx parameter to 0 if we have duplicate embedding copy

@@ -421,7 +427,7 @@ def split_partition(
# Add 1 to offset to account for last PP rank's duplicated Embedding
offset_diff = offset - num_params
# GPT offset correction
if pp_size > 1 and pp_rank == (pp_size - 1) and pp_split_rank == 0:
if not untied_gpt_embedding and pp_size > 1 and pp_rank == (pp_size - 1) and pp_split_rank == 0:
offset_diff += 1
# T5 offset correction for shared embedding when pp split rank == pp rank
if shared_enc_dec_embeddings:
@@ -473,7 +479,7 @@ def split_partition(

def main():
parser = ArgumentParser()
parser.add_argument("--model_file", type=str, required=True, help="Path to source .nemo file")
parser.add_argument("--model_file", type=str, default=None, required=False, help="Path to source .nemo file")
parser.add_argument("--target_file", type=str, required=True, help="Path to write target .nemo file")
parser.add_argument(
"--tensor_model_parallel_size", type=int, default=-1, required=False, help="TP size of source model"
@@ -552,6 +558,9 @@ def main():
pipeline_model_parallel_split_rank = args.target_pipeline_model_parallel_split_rank
cls = model_utils.import_class_by_path(args.model_class)

if args.model_file is None and args.model_extracted_dir is None:
raise ValueError("Cannot pass model_file and model_extracted_dir as None at the same time.")

trainer = Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision)

if tp_size < 0 or pp_size < 0:
@@ -614,8 +623,13 @@ def main():
logging.info(f"Using extracted model directory: {args.model_extracted_dir}")
save_restore_connector.model_extracted_dir = args.model_extracted_dir

if args.model_file is not None:
model_filepath = args.model_file
else:
model_filepath = args.model_extracted_dir

model = cls.restore_from(
restore_path=args.model_file,
restore_path=model_filepath,
trainer=trainer,
map_location=torch.device("cpu"),
save_restore_connector=save_restore_connector,
@@ -684,7 +698,24 @@ def main():
else:
# If input model has TP = 1 and PP = 1
app_state.model_parallel_size = 1
model = cls.restore_from(restore_path=args.model_file, trainer=trainer, map_location=torch.device("cpu"))

save_restore_connector = NLPSaveRestoreConnector()

if args.model_extracted_dir is not None:
logging.info(f"Using extracted model directory: {args.model_extracted_dir}")
save_restore_connector.model_extracted_dir = args.model_extracted_dir

if args.model_file is not None:
model_filepath = args.model_file
else:
model_filepath = args.model_extracted_dir

model = cls.restore_from(
restore_path=model_filepath,
trainer=trainer,
map_location=torch.device("cpu"),
save_restore_connector=save_restore_connector,
)
model.to(dtype=dtype)

# If target model has TP > 1 or PP > 1
8 changes: 4 additions & 4 deletions examples/nlp/language_modeling/megatron_ckpt_to_nemo.py
@@ -28,7 +28,7 @@
from argparse import ArgumentParser

import torch
from apex.transformer import parallel_state
from megatron.core import parallel_state
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.trainer import Trainer

@@ -121,9 +121,9 @@ def convert(local_rank, rank, world_size, args):
app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size

parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
tensor_model_parallel_size=app_state.tensor_model_parallel_size,
pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
)

app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
8 changes: 5 additions & 3 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -33,11 +33,13 @@
from nemo.utils.model_utils import inject_model_parallel_rank

try:
from apex.transformer import parallel_state
from megatron.core import parallel_state

HAVE_MEGATRON_CORE = True

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

HAVE_MEGATRON_CORE = False

"""
This is the script to run GPT text generation.
@@ -14,7 +14,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
1 change: 1 addition & 0 deletions examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
@@ -42,6 +42,7 @@
from typing import Any, Optional

import torch
from megatron.core import parallel_state
from pytorch_lightning.core.saving import _load_state as ptl_load_state
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml
from pytorch_lightning.trainer.trainer import Trainer
8 changes: 5 additions & 3 deletions examples/nlp/language_modeling/megatron_retro_eval.py
@@ -25,11 +25,13 @@
from nemo.core.config import hydra_runner

try:
from apex.transformer import parallel_state
from megatron.core import parallel_state

HAVE_MEGATRON_CORE = True

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

HAVE_MEGATRON_CORE = False

"""
This is the script to run RETRO Model text generation.
@@ -26,11 +26,11 @@
from nemo.utils.app_state import AppState

try:
from apex.transformer import parallel_state
from megatron.core import parallel_state

HAVE_APEX = True
HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False
HAVE_MEGATRON_CORE = False


if not torch.cuda.is_available():
@@ -15,7 +15,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
@@ -15,7 +15,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
@@ -15,7 +15,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
@@ -15,7 +15,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
3 changes: 3 additions & 0 deletions examples/nlp/machine_translation/megatron_nmt_training.py
@@ -13,6 +13,7 @@
# limitations under the License.


import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
@@ -33,6 +34,8 @@
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

mp.set_start_method("spawn", force=True)


@hydra_runner(config_path="conf", config_name="aayn_base_megatron")
def main(cfg) -> None:
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/label_models.py
@@ -427,9 +427,9 @@ def infer_file(self, path2audio_file):
audio = librosa.core.resample(audio, orig_sr=sr, target_sr=target_sr)
audio_length = audio.shape[0]
device = self.device
audio = np.array(audio)
audio = np.array([audio])
audio_signal, audio_signal_len = (
torch.tensor([audio], device=device),
torch.tensor(audio, device=device),
torch.tensor([audio_length], device=device),
)
mode = self.training
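
The change above adds the batch dimension at the NumPy level rather than when the torch tensor is built. A minimal sketch of the resulting shapes (illustrative only; the signal length is made up):

    import numpy as np
    import torch

    audio = np.zeros(16000, dtype=np.float32)  # assumed: 1 second of mono audio at 16 kHz

    # Updated path: wrap the waveform in NumPy, then build the tensor from a single ndarray
    audio_batched = np.array([audio])                   # shape (1, 16000)
    audio_signal = torch.tensor(audio_batched)          # torch.Size([1, 16000])
    audio_signal_len = torch.tensor([audio.shape[0]])   # torch.Size([1])

    # The earlier torch.tensor([audio]) call yielded the same (1, 16000) shape,
    # but built the tensor from a Python list containing an ndarray.
    print(audio_signal.shape, audio_signal_len.shape)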
22 changes: 20 additions & 2 deletions nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py
@@ -382,6 +382,7 @@ def __init__(
max_seq_length_decoder: int = 128,
use_cache: bool = True,
prefix_override: str = None,
pad_to_max_length: bool = True,
):
"""
Processes GLUE datasets
@@ -392,10 +393,12 @@ def __init__(
max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
use_cache: whether to use data cache
prefix_override: if you want to override default prompt for this task specify this via a string.
pad_to_max_length: If true, pad to the maximum length.
"""
super().__init__(file_name, task_name, tokenizer, max_seq_length, use_cache, compute_features=False)
self.max_seq_length = max_seq_length
self.max_seq_length_decoder = max_seq_length_decoder
self.pad_to_max_length = pad_to_max_length
self.processor = processors[self.task_name]()
self.prefix_override = prefix_override
self.features = self.convert_examples_to_features()
@@ -412,9 +415,16 @@ def collate_fn(self, batch):
dec_input = [item['text_dec'] for item in batch]
labels = [item['labels'] for item in batch]

max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0
max_enc_query_length = max([len(item) for item in enc_query]) if enc_query else 0
max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0
max_label_length = max([len(item) for item in labels]) if labels else 0
if self.pad_to_max_length:
assert max_enc_query_length <= self.max_seq_length
assert max_dec_input_length <= self.max_seq_length_decoder
assert max_label_length <= self.max_seq_length_decoder
max_enc_query_length = self.max_seq_length
max_dec_input_length = self.max_seq_length_decoder
max_label_length = self.max_seq_length_decoder

loss_mask = [([1] * (len(item))) + ([0] * (max_label_length - len(item))) for item in labels]
enc_query = [item + [self.tokenizer.pad_id] * (max_enc_query_length - len(item)) for item in enc_query]
@@ -488,10 +498,18 @@ def __init__(
use_cache: bool = True,
prefix_override: str = None,
lang_list: List[str] = None,
pad_to_max_length: bool = True,
):
self.lang_list = set(lang_list)
super().__init__(
file_name, task_name, tokenizer, max_seq_length, max_seq_length_decoder, use_cache, prefix_override
file_name,
task_name,
tokenizer,
max_seq_length,
max_seq_length_decoder,
use_cache,
prefix_override,
pad_to_max_length,
)
if len(lang_list) <= 0 or lang_list is None:
raise ValueError(f"Found an empty or None lang_list for {self.task_name}")
