Move Parallelism usage from Apex -> Megatron Core #6393

Merged: 48 commits, Apr 13, 2023

Commits (48)

23a46cb - import parallel_state and tensor_parallel from megatron.core (ericharper, Oct 5, 2022)
13b1572 - update column parallel async allreduce arg (ericharper, Oct 5, 2022)
957e1b0 - typos (ericharper, Oct 5, 2022)
3bd627c - merge main into base branch (aklife97, Feb 23, 2023)
59e7772 - play stash + some changes (aklife97, Feb 23, 2023)
571dbc7 - make grad scaler callable (ericharper, Mar 8, 2023)
14721f2 - Fixed formatting (SeanNaren, Mar 16, 2023)
dea8098 - Make sure RETRO integrates well with the core (#6207) (yidong72, Mar 17, 2023)
3180db5 - merge main to base branch (aklife97, Mar 30, 2023)
a7c502b - [NLP] Support T5 with Megatron Core (#6222) (SeanNaren, Apr 4, 2023)
1ff66aa - GPT P-tuning core (max_len pad -> slow) (aklife97, Apr 4, 2023)
874286b - add GPT p-tuning w/ global batch based passing (aklife97, Apr 5, 2023)
3cb4230 - Merge branch 'GPT_integrate_core' of github.com:NVIDIA/NeMo into GPT_… (aklife97, Apr 5, 2023)
f92e725 - add T5 p-tuning support (aklife97, Apr 6, 2023)
eb68b32 - add megatron core install to Jenkinsfile (ericharper, Apr 7, 2023)
d6c9c15 - fix command (ericharper, Apr 7, 2023)
304cc1c - add guard efault for arg (ericharper, Apr 7, 2023)
9fe410d - shift bert, retro, adapter + other namespace changes (aklife97, Apr 7, 2023)
d784b1d - Merge branch 'GPT_integrate_core' of github.com:NVIDIA/NeMo into GPT_… (aklife97, Apr 7, 2023)
df1d5d1 - build_model merge into one (aklife97, Apr 7, 2023)
a3db3aa - Ensure fine-tuning/prompt learning work for T5 (#6385) (SeanNaren, Apr 7, 2023)
1ba9fa6 - rm extra split impl (aklife97, Apr 7, 2023)
8330183 - fix for CI (aklife97, Apr 7, 2023)
0d27220 - temp change for tests (aklife97, Apr 8, 2023)
f136f89 - add bs=1 for log (aklife97, Apr 8, 2023)
d063437 - fix (aklife97, Apr 10, 2023)
b44b145 - iter changes NMT (aklife97, Apr 10, 2023)
10c668d - NMT partial fix (aklife97, Apr 10, 2023)
af19bf2 - move on_train_batch_end to base_model (aklife97, Apr 10, 2023)
f44235d - rm on_train_batch_end (aklife97, Apr 10, 2023)
9b6982d - temp remove NMT test (aklife97, Apr 11, 2023)
738c132 - add training_step logic for T5 derived dynamic len models (aklife97, Apr 11, 2023)
f9585b1 - add NMT test back (aklife97, Apr 11, 2023)
492cd90 - style fix (aklife97, Apr 11, 2023)
f068656 - change no_async_tensor_model_parallel_allreduce (aklife97, Apr 11, 2023)
1c07e05 - sequence_parallel_enabled -> sequence_parallel (aklife97, Apr 11, 2023)
17e69fc - fix T5 FT batch size (aklife97, Apr 11, 2023)
35173a7 - seq enabled (aklife97, Apr 11, 2023)
644b2f5 - T5 sequence length fix (aklife97, Apr 11, 2023)
ff94631 - NMT mp fork to spawn (aklife97, Apr 11, 2023)
47f2e01 - make function signatures consistent across models (aklife97, Apr 11, 2023)
e2f174e - merge main into branch (aklife97, Apr 11, 2023)
7216232 - Merge branch 'main' into GPT_integrate_core (aklife97, Apr 11, 2023)
71ea0e7 - make print log (aklife97, Apr 11, 2023)
4675db9 - Merge branch 'GPT_integrate_core' of github.com:NVIDIA/NeMo into GPT_… (aklife97, Apr 11, 2023)
8ce981c - rm unused import (aklife97, Apr 11, 2023)
90f19d6 - update Dockerfile to install core (aklife97, Apr 12, 2023)
6fb8ee9 - keep core path in workspace (aklife97, Apr 12, 2023)

Files changed

7 changes: 6 additions & 1 deletion Dockerfile
@@ -42,8 +42,13 @@ RUN apt-get update && \
libavdevice-dev && \
rm -rf /var/lib/apt/lists/*

WORKDIR /tmp/
WORKDIR /workspace/
# Install Megatron-core
RUN git clone https://github.com/aklife97/Megatron-LM.git && \
cd Megatron-LM && \
pip install -e .

WORKDIR /tmp/
# TODO: Remove once this Apex commit (2/24/23) is included in PyTorch
# container
RUN git clone https://github.com/NVIDIA/apex.git && \
8 changes: 8 additions & 0 deletions Jenkinsfile
@@ -57,6 +57,14 @@ pipeline {
}
}

// TODO: remove when pip package is available
stage('Megatron Core installation') {
steps {
sh 'git clone https://github.com/aklife97/Megatron-LM.git && \
cd Megatron-LM && \
pip install -e .'
}
}

stage('PyTorch Lightning version') {
steps {
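
Both the Dockerfile and the Jenkinsfile now install Megatron Core from source with pip install -e . A minimal sanity check of that installation (a sketch, not part of this PR) is importing the two modules the rest of the diff switches over to:

# Sanity-check sketch (not part of the PR): confirm the editable Megatron-LM
# install exposes the modules NeMo now imports instead of apex.transformer.
from megatron.core import parallel_state, tensor_parallel

# Model parallelism stays uninitialized until initialize_model_parallel() is called.
print(parallel_state.model_parallel_is_initialized())  # False
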
8 changes: 4 additions & 4 deletions examples/nlp/language_modeling/megatron_ckpt_to_nemo.py
@@ -28,7 +28,7 @@
from argparse import ArgumentParser

import torch
from apex.transformer import parallel_state
from megatron.core import parallel_state
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.trainer import Trainer

@@ -121,9 +121,9 @@ def convert(local_rank, rank, world_size, args):
app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size

parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=app_state.tensor_model_parallel_size,
pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size,
pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
tensor_model_parallel_size=app_state.tensor_model_parallel_size,
pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
)

app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
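
The only functional change in this hunk is dropping the trailing underscores that the Apex keyword arguments carried. A call sketch against megatron.core (single-process example; the gloo backend and localhost address are illustrative assumptions, not values taken from this script):

# Sketch of the renamed keyword arguments. Assumes a single process; the backend
# and init_method below are illustrative, not NeMo defaults.
import torch.distributed as dist
from megatron.core import parallel_state

if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", world_size=1, rank=0
    )

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=1,             # was tensor_model_parallel_size_
    pipeline_model_parallel_size=1,           # was pipeline_model_parallel_size_
    pipeline_model_parallel_split_rank=None,  # was pipeline_model_parallel_split_rank_
)
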
8 changes: 5 additions & 3 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -33,11 +33,13 @@
from nemo.utils.model_utils import inject_model_parallel_rank

try:
from apex.transformer import parallel_state
from megatron.core import parallel_state

HAVE_MEGATRON_CORE = True

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

HAVE_MEGATRON_CORE = False

"""
This is the script to run GPT text generation.
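
The same guarded-import change repeats across the scripts and modules below: the apex.transformer import and HAVE_APEX flag become a megatron.core import and a HAVE_MEGATRON_CORE flag, and the error raised when the flag is False is updated to match. The pattern in isolation (a sketch assembled from the hunks in this PR):

# Guarded import used throughout this PR: probe for megatron.core, record
# availability, and fail with an actionable message before any parallel-state use.
try:
    from megatron.core import parallel_state

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    HAVE_MEGATRON_CORE = False

if not HAVE_MEGATRON_CORE:
    raise ImportError(
        "megatron-core was not found. Please see the NeMo README for installation "
        "instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
    )
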
(file name not shown)
@@ -14,7 +14,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
1 change: 1 addition & 0 deletions examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
@@ -42,6 +42,7 @@
from typing import Any, Optional

import torch
from megatron.core import parallel_state
from pytorch_lightning.core.saving import _load_state as ptl_load_state
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml
from pytorch_lightning.trainer.trainer import Trainer
8 changes: 5 additions & 3 deletions examples/nlp/language_modeling/megatron_retro_eval.py
@@ -25,11 +25,13 @@
from nemo.core.config import hydra_runner

try:
from apex.transformer import parallel_state
from megatron.core import parallel_state

HAVE_MEGATRON_CORE = True

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

HAVE_MEGATRON_CORE = False

"""
This is the script to run RETRO Model text generation.
(file name not shown)
@@ -26,11 +26,11 @@
from nemo.utils.app_state import AppState

try:
from apex.transformer import parallel_state
from megatron.core import parallel_state

HAVE_APEX = True
HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False
HAVE_MEGATRON_CORE = False


if not torch.cuda.is_available():
(file name not shown)
@@ -15,7 +15,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
(file name not shown)
@@ -15,7 +15,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
(file name not shown)
@@ -15,7 +15,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
(file name not shown)
@@ -15,7 +15,7 @@

import torch
import torch.multiprocessing as mp
from apex.transformer import parallel_state
from megatron.core import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
3 changes: 3 additions & 0 deletions examples/nlp/machine_translation/megatron_nmt_training.py
@@ -13,6 +13,7 @@
# limitations under the License.


import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
@@ -33,6 +34,8 @@
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

mp.set_start_method("spawn", force=True)


@hydra_runner(config_path="conf", config_name="aayn_base_megatron")
def main(cfg) -> None:
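
megatron_nmt_training.py now forces the spawn start method (the "NMT mp fork to spawn" commit above): CUDA contexts cannot be reused in forked child processes, so workers for this script have to be spawned. A small sketch of the effect:

# Sketch: force=True overrides any start method chosen earlier in the process,
# so multiprocessing children created after this point are spawned, not forked.
import torch.multiprocessing as mp

mp.set_start_method("spawn", force=True)
print(mp.get_start_method())  # "spawn"
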
22 changes: 20 additions & 2 deletions nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py
@@ -382,6 +382,7 @@ def __init__(
max_seq_length_decoder: int = 128,
use_cache: bool = True,
prefix_override: str = None,
pad_to_max_length: bool = True,
):
"""
Processes GLUE datasets
@@ -392,10 +393,12 @@
max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
use_cache: whether to use data cache
prefix_override: if you want to override default prompt for this task specify this via a string.
pad_to_max_length: If true, pad to the maximum length.
"""
super().__init__(file_name, task_name, tokenizer, max_seq_length, use_cache, compute_features=False)
self.max_seq_length = max_seq_length
self.max_seq_length_decoder = max_seq_length_decoder
self.pad_to_max_length = pad_to_max_length
self.processor = processors[self.task_name]()
self.prefix_override = prefix_override
self.features = self.convert_examples_to_features()
@@ -412,9 +415,16 @@ def collate_fn(self, batch):
dec_input = [item['text_dec'] for item in batch]
labels = [item['labels'] for item in batch]

max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0
max_enc_query_length = max([len(item) for item in enc_query]) if enc_query else 0
max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0
max_label_length = max([len(item) for item in labels]) if labels else 0
if self.pad_to_max_length:
assert max_enc_query_length <= self.max_seq_length
assert max_dec_input_length <= self.max_seq_length_decoder
assert max_label_length <= self.max_seq_length_decoder
max_enc_query_length = self.max_seq_length
max_dec_input_length = self.max_seq_length_decoder
max_label_length = self.max_seq_length_decoder

loss_mask = [([1] * (len(item))) + ([0] * (max_label_length - len(item))) for item in labels]
enc_query = [item + [self.tokenizer.pad_id] * (max_enc_query_length - len(item)) for item in enc_query]
@@ -488,10 +498,18 @@ def __init__(
use_cache: bool = True,
prefix_override: str = None,
lang_list: List[str] = None,
pad_to_max_length: bool = True,
):
self.lang_list = set(lang_list)
super().__init__(
file_name, task_name, tokenizer, max_seq_length, max_seq_length_decoder, use_cache, prefix_override
file_name,
task_name,
tokenizer,
max_seq_length,
max_seq_length_decoder,
use_cache,
prefix_override,
pad_to_max_length,
)
if len(lang_list) <= 0 or lang_list is None:
raise ValueError(f"Found an empty or None lang_list for {self.task_name}")
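
The new pad_to_max_length flag decides whether collate_fn pads every batch to the dataset's configured maxima (the fixed-length behavior the "max_len pad -> slow" commit refers to) or only to the longest item actually present. A toy sketch of that decision, with hypothetical lengths rather than NeMo tensors:

# Toy sketch of the padding-length decision added to collate_fn (hypothetical
# helper and lengths, not NeMo code).
def padded_lengths(enc_lens, dec_lens, max_seq_length, max_seq_length_decoder, pad_to_max_length):
    max_enc = max(enc_lens) if enc_lens else 0
    max_dec = max(dec_lens) if dec_lens else 0
    if pad_to_max_length:
        # Fixed-size batches: everything is padded to the configured maxima.
        assert max_enc <= max_seq_length and max_dec <= max_seq_length_decoder
        return max_seq_length, max_seq_length_decoder
    # Dynamic batches: pad only to the longest sequence in this batch.
    return max_enc, max_dec

print(padded_lengths([5, 9], [3, 4], 128, 64, True))   # (128, 64)
print(padded_lengths([5, 9], [3, 4], 128, 64, False))  # (9, 4)
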
(file name not shown)
@@ -54,13 +54,13 @@
from nemo.utils.get_rank import is_global_rank_zero

try:
from apex.transformer import parallel_state
from megatron.core import parallel_state

HAVE_APEX = True
HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

HAVE_APEX = False
HAVE_MEGATRON_CORE = False


DSET_TYPE_BERT = 'standard_bert'
(file name not shown)
@@ -32,13 +32,13 @@
from nemo.utils import logging

try:
from apex.transformer import parallel_state
from megatron.core import parallel_state

HAVE_APEX = True
HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

HAVE_APEX = False
HAVE_MEGATRON_CORE = False


def build_dataset(cfg, trainer, data_prefix, data_impl, num_samples, seq_length, seed, skip_warmup, tokenizer, name):
@@ -303,9 +303,9 @@ def __init__(
seed,
drop_last=True,
):
if not HAVE_APEX:
if not HAVE_MEGATRON_CORE:
raise ImportError(
"Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
"megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
)

super().__init__()
@@ -432,9 +432,9 @@ class MockGPTDataset(Dataset):
def __init__(
self, cfg, tokenizer, name, num_samples, seq_length, seed,
):
if not HAVE_APEX:
if not HAVE_MEGATRON_CORE:
raise ImportError(
"Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
"Megatron core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
)

super().__init__()
(file name not shown)
@@ -69,9 +69,9 @@ def __init__(
micro_batch_size: The size of each micro batch.
global_batch_size: The size of global batch.
data_parallel_rank: The value you can obtain via
`parallel_state.get_data_parallel_rank()` of apex.transformer.
`parallel_state.get_data_parallel_rank()` of megatron.core.
data_parallel_size: The value you can obtain via
`parallel_state.get_data_parallel_world_size()` of apex.transformer.
`parallel_state.get_data_parallel_world_size()` of megatron.core.
"""
# Sanity checks.
if total_samples <= 0:
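
The docstring now points at megatron.core for these two sampler inputs. Where they come from, assuming model parallelism has already been initialized as in the scripts above:

# Sketch (assumes parallel_state.initialize_model_parallel() has already run).
from megatron.core import parallel_state

data_parallel_rank = parallel_state.get_data_parallel_rank()
data_parallel_size = parallel_state.get_data_parallel_world_size()
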
(file name not shown)
@@ -37,11 +37,13 @@
from nemo.utils import logging

try:
from apex.transformer import parallel_state
from megatron.core import parallel_state

HAVE_MEGATRON_CORE = True

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

HAVE_MEGATRON_CORE = False

__all__ = [
"RETRODataset",
@@ -76,9 +78,9 @@ def __init__(
knn_index: KNNIndex,
retrieval_index: MMapRetrievalIndexedDataset,
):
if not HAVE_APEX:
if not HAVE_MEGATRON_CORE:
raise ImportError(
"Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
"megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
)

super().__init__()
(file name not shown)
@@ -17,8 +17,8 @@
try:
from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel

HAVE_APEX = True
HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False
HAVE_MEGATRON_CORE = False

# from nemo.collections.nlp.models.language_modeling.megatron.t5_model import T5Model
(file name not shown)
@@ -31,16 +31,26 @@
)

try:
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.enums import AttnMaskType
from apex.transformer.tensor_parallel.layers import set_tensor_model_parallel_attributes

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):

HAVE_APEX = False

# fake missing classes with None attributes
AttnMaskType = ApexGuardDefaults()

try:
from megatron.core import parallel_state, tensor_parallel

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

HAVE_MEGATRON_CORE = False


def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
(file name not shown)
@@ -26,15 +26,26 @@
)

try:
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.enums import AttnMaskType

HAVE_APEX = True

except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

# fake missing classes with None attributes
AttnMaskType = ApexGuardDefaults()

HAVE_APEX = False

try:
from megatron.core import parallel_state, tensor_parallel

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

HAVE_MEGATRON_CORE = False


def post_language_model_processing(
lm_output,
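
In these model files the single Apex guard is split in two: Apex is still probed for the pieces only it provides (AttnMaskType, set_tensor_model_parallel_attributes), while parallel_state and tensor_parallel now come from megatron.core. The shape of the resulting guards, sketched with None standing in for NeMo's ApexGuardDefaults:

# Sketch of the split import guard used in the model files above.
try:
    from apex.transformer.enums import AttnMaskType

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False
    AttnMaskType = None  # NeMo substitutes ApexGuardDefaults() here

try:
    from megatron.core import parallel_state, tensor_parallel

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    HAVE_MEGATRON_CORE = False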