Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start using ModelParallelConfig from Megatron Core #6885

Merged
merged 57 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
63c127d
start adding gpt from megatron core path
ericharper Jun 7, 2023
16d85c4
set model parallel config
ericharper Jun 9, 2023
19e1420
use model parallel config object
ericharper Jun 19, 2023
a309c0b
update args
ericharper Jun 22, 2023
2ea9285
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2023
46ec121
set vp size to none if it is 1
ericharper Jun 23, 2023
575ef8a
set vp size to none if it is 1
ericharper Jun 23, 2023
a8b177c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 22, 2023
a296be8
add TransformerConfig
ericharper Jun 28, 2023
ec3c170
start updating to TransformerConfig
ericharper Jun 28, 2023
e2090ae
add todo
ericharper Jun 29, 2023
e1f38d8
revert to model parallel config
ericharper Jun 30, 2023
cbfb0d4
add hidden_size to model_parallel_config
ericharper Jun 30, 2023
cbf5036
remove imports
ericharper Jun 30, 2023
c6fe7ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2023
2bd408c
remove import
ericharper Jun 30, 2023
06064bf
small clean up
ericharper Jun 30, 2023
d8e9f4f
update hidden size in peft base model, add mcore commit to jenkins
ericharper Jun 30, 2023
afdf3f0
update module args
ericharper Jul 1, 2023
90e8160
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2023
3f194ac
add config obj to flash attention tests
ericharper Jul 5, 2023
95a3b68
remove args
ericharper Jul 5, 2023
a0e133f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2023
e2cc69a
remove sequence parallel arg
ericharper Jul 5, 2023
b4b5b06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2023
bcbd787
update args
ericharper Jul 7, 2023
f31ff9f
add config to self
ericharper Jul 7, 2023
dd59c1c
update args
ericharper Jul 7, 2023
a6d124e
update args
ericharper Jul 7, 2023
e36d8c1
update args
ericharper Jul 7, 2023
215b1e5
add config to test
ericharper Jul 10, 2023
61adaab
get hidden_size from config
ericharper Jul 10, 2023
8de4993
add try except
ericharper Jul 10, 2023
48942cc
use default
ericharper Jul 10, 2023
7a4e884
update config with hidden size
ericharper Jul 11, 2023
0c008fd
remove arg
ericharper Jul 11, 2023
148e19a
comment out jenkins test
ericharper Jul 13, 2023
6d6a69b
revert import
ericharper Jul 13, 2023
bcc6072
remove optimizer_idx
ericharper Aug 8, 2023
aa5b5fb
prefetch num microbatches
ericharper Aug 8, 2023
2db6c2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 8, 2023
e5d48ae
Merge branch 'main' into mcore_gpt_path
ericharper Aug 8, 2023
6199567
Merge branch 'main' into mcore_gpt_path
ericharper Aug 10, 2023
4ddc99f
remove import
ericharper Aug 10, 2023
4551eb6
temporarily comment jenkins test
ericharper Aug 10, 2023
072dcad
pull main
ericharper Aug 11, 2023
a5f24e6
update seq_length
ericharper Aug 11, 2023
8f2e8fb
remove commented code
ericharper Aug 11, 2023
f07af89
update arg
ericharper Aug 12, 2023
c36054c
resolve conflict
ericharper Aug 12, 2023
7dcf6b7
update mbs and gbs of test
ericharper Aug 12, 2023
7519e0f
update batch size in test
ericharper Aug 13, 2023
7aa1188
fix precision in test
ericharper Aug 13, 2023
ceca1f3
update precision
ericharper Aug 13, 2023
82a55f5
move hidden_size out of conditional
ericharper Aug 14, 2023
0a702a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 14, 2023
e4ba515
Merge branch 'main' into mcore_gpt_path
ericharper Aug 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pipeline {
agent {
docker {
image 'nvcr.io/nvidia/pytorch:23.06-py3'
args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache:/root/.cache --shm-size=8g --env TRANSFORMERS_OFFLINE=1'
args '--device=/dev/nvidia0 --gpus all --user 0:128 -v /home/TestData:/home/TestData -v $HOME/.cache:/root/.cache --shm-size=8g --env TRANSFORMERS_OFFLINE=1 --env HYDRA_FULL_ERROR=1'
}
}
options {
Expand Down Expand Up @@ -59,10 +59,10 @@ pipeline {

stage('Megatron Core installation') {
steps {
// commit points to core 23.05 ToT
// commit points to core_transformer merge
sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
git checkout 060415572f4365a2e895f8036c4e37dad0efbdf5 && \
git checkout 3316e811cc5335ee24c2d203416d864edcf2f7a8 && \
pip install -e .'
}
}
Expand Down Expand Up @@ -164,19 +164,21 @@ pipeline {
}
}

stage('L2: Speech Pre-training - Wav2Vec') {
steps {
sh 'python examples/asr/speech_pretraining/speech_pre_training.py \
--config-path="../conf/ssl/wav2vec/" --config-name="wav2vec_ci" \
model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
trainer.devices=[1] \
trainer.accelerator="gpu" \
+trainer.fast_dev_run=True \
exp_manager.exp_dir=examples/asr/speech_pre_training_results'
sh 'rm -rf examples/asr/speech_pre_training_results'
}
}
// TODO: Please Fix Me
// Error locating target 'nemo.collections.asr.modules.wav2vec_modules.ConvFeatureEncoder', see chained exception above.
// stage('L2: Speech Pre-training - Wav2Vec') {
// steps {
// sh 'python examples/asr/speech_pretraining/speech_pre_training.py \
// --config-path="../conf/ssl/wav2vec/" --config-name="wav2vec_ci" \
// model.train_ds.manifest_filepath=/home/TestData/an4_dataset/an4_train.json \
// model.validation_ds.manifest_filepath=/home/TestData/an4_dataset/an4_val.json \
// trainer.devices=[1] \
// trainer.accelerator="gpu" \
// +trainer.fast_dev_run=True \
// exp_manager.exp_dir=examples/asr/speech_pre_training_results'
// sh 'rm -rf examples/asr/speech_pre_training_results'
// }
// }

stage('L2: Speech to Text WPE - Conformer') {
steps {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
AttnMaskType = ApexGuardDefaults()

try:
from megatron.core import parallel_state, tensor_parallel
from megatron.core import ModelParallelConfig, parallel_state, tensor_parallel

HAVE_MEGATRON_CORE = True

Expand Down Expand Up @@ -82,22 +82,22 @@ class BertLMHead(MegatronModule):

def __init__(
self,
config: ModelParallelConfig,
mpu_vocab_size,
hidden_size,
init_method,
layernorm_epsilon,
parallel_output,
use_openai_gelu,
onnx_safe,
sequence_parallel=False,
):

super(BertLMHead, self).__init__()
super(BertLMHead, self).__init__(config=config)

self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.sequence_parallel = sequence_parallel
self.sequence_parallel = config.sequence_parallel

self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = get_layer_norm(hidden_size, eps=layernorm_epsilon)
Expand All @@ -111,7 +111,7 @@ def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
async_tensor_model_parallel_allreduce = parallel_state.get_tensor_model_parallel_world_size() > 1
async_tensor_model_parallel_allreduce = self.config.async_tensor_model_parallel_allreduce
output = parallel_lm_logits(
hidden_states,
word_embeddings_weight,
Expand Down Expand Up @@ -157,6 +157,7 @@ class BertModel(MegatronModule):

def __init__(
self,
config: ModelParallelConfig,
vocab_size,
hidden_size,
max_position_embeddings,
Expand All @@ -171,7 +172,6 @@ def __init__(
post_process=True,
init_method_std=0.02,
fp16_lm_cross_entropy=False,
use_cpu_initialization=False,
megatron_amp_O2=False,
hidden_dropout=0.1,
precision=16,
Expand All @@ -190,8 +190,7 @@ def __init__(
sequence_parallel=False,
position_embedding_type='learned_absolute',
):
super(BertModel, self).__init__()
# args = get_args()
super(BertModel, self).__init__(config=config)
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.add_binary_head = add_binary_head
self.parallel_output = parallel_output
Expand All @@ -203,6 +202,7 @@ def __init__(
scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)

self.language_model, self._language_model_key = get_language_model(
config=config,
vocab_size=vocab_size,
hidden_size=hidden_size,
hidden_dropout=hidden_dropout,
Expand All @@ -220,7 +220,6 @@ def __init__(
pre_process=self.pre_process,
post_process=self.post_process,
init_method_std=init_method_std,
use_cpu_initialization=use_cpu_initialization,
megatron_amp_O2=megatron_amp_O2,
precision=precision,
fp32_residual_connection=fp32_residual_connection,
Expand All @@ -234,7 +233,6 @@ def __init__(
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
megatron_legacy=megatron_legacy,
sequence_parallel=sequence_parallel,
position_embedding_type=position_embedding_type,
)

Expand All @@ -244,14 +242,14 @@ def __init__(

if self.post_process:
self.lm_head = BertLMHead(
config,
self.word_embeddings_weight().size(0),
hidden_size,
init_method,
layernorm_epsilon,
parallel_output,
openai_gelu,
onnx_safe,
sequence_parallel,
)
self._lm_head_key = 'lm_head'
self.binary_head = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
HAVE_APEX = False

try:
from megatron.core import parallel_state, tensor_parallel
from megatron.core import ModelParallelConfig, parallel_state, tensor_parallel

HAVE_MEGATRON_CORE = True

Expand Down Expand Up @@ -108,6 +108,7 @@ class GPTModel(MegatronModule):

def __init__(
self,
config: ModelParallelConfig,
vocab_size,
hidden_size,
max_position_embeddings,
Expand All @@ -123,7 +124,6 @@ def __init__(
init_method_std=0.02,
use_scaled_init_method=True,
fp16_lm_cross_entropy=False,
use_cpu_initialization=False,
megatron_amp_O2=False,
hidden_dropout=0.1,
attention_dropout=0.1,
Expand All @@ -148,12 +148,10 @@ def __init__(
rotary_percentage=1.0,
attention_type='multihead',
share_embeddings_and_output_weights=True,
gradient_accumulation_fusion=False,
persist_layer_norm=False,
openai_gelu=False,
megatron_legacy=False,
onnx_safe=False,
sequence_parallel=False,
transformer_engine=False,
fp8=False,
fp8_e4m3=False,
Expand All @@ -168,14 +166,13 @@ def __init__(
use_flash_attention=False,
seq_len_interpolation_factor=None,
):
super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)
super(GPTModel, self).__init__(config=config, share_token_embeddings=share_embeddings_and_output_weights)

self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.sequence_parallel = sequence_parallel
self.gradient_accumulation_fusion = gradient_accumulation_fusion
self.sequence_parallel = self.config.sequence_parallel
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.dtype = utils_funcs.dtype_from_precision(precision, megatron_amp_O2)

Expand All @@ -191,6 +188,7 @@ def __init__(
else init_method_normal(init_method_std)
)
self.language_model, self._language_model_key = get_language_model(
config=config,
vocab_size=vocab_size,
hidden_size=hidden_size,
hidden_dropout=hidden_dropout,
Expand All @@ -210,7 +208,6 @@ def __init__(
pre_process=self.pre_process,
post_process=self.post_process,
init_method_std=init_method_std,
use_cpu_initialization=use_cpu_initialization,
megatron_amp_O2=megatron_amp_O2,
precision=precision,
fp32_residual_connection=fp32_residual_connection,
Expand All @@ -226,7 +223,6 @@ def __init__(
bias_activation_fusion=bias_activation_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
gradient_accumulation_fusion=gradient_accumulation_fusion,
activation=activation,
headscale=headscale,
transformer_block_type=transformer_block_type,
Expand All @@ -237,7 +233,6 @@ def __init__(
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
megatron_legacy=megatron_legacy,
sequence_parallel=sequence_parallel,
transformer_engine=transformer_engine,
fp8=fp8,
fp8_e4m3=fp8_e4m3,
Expand Down Expand Up @@ -309,7 +304,7 @@ def forward(
self.fp16_lm_cross_entropy,
return_logits=encoder_input is not None,
sequence_parallel=self.sequence_parallel,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
gradient_accumulation_fusion=self.config.gradient_accumulation_fusion,
)
else:
return lm_output
Expand Down
Loading
Loading