[NLP] Support T5 with Megatron Core #6222
Merged
Conversation
SeanNaren commented on Mar 16, 2023
nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py (outdated)
Seems the formatting on the file was off; I've pushed formatting changes directly to the branch.
SeanNaren force-pushed the GPT_integrate_core_t5 branch from 89c96ed to 5ddf524 on March 16, 2023 at 12:36
SeanNaren force-pushed the GPT_integrate_core_t5 branch from 6d1dc59 to f2451f8 on March 22, 2023 at 10:54
aklife97 reviewed on Mar 30, 2023:
nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py (resolved)
aklife97 reviewed on Mar 30, 2023:
nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py (outdated, resolved)
SeanNaren force-pushed the GPT_integrate_core_t5 branch from 6f74d0d to ef2d18e on April 4, 2023 at 21:09
ericharper added a commit that referenced this pull request on Apr 13, 2023:
* import parallel_state and tensor_parallel from megatron.core Signed-off-by: ericharper <[email protected]>
* update column parallel async allreduce arg Signed-off-by: ericharper <[email protected]>
* typos Signed-off-by: ericharper <[email protected]>
* play stash + some changes Signed-off-by: Abhinav Khattar <[email protected]>
* make grad scaler callable Signed-off-by: ericharper <[email protected]>
* Fixed formatting Signed-off-by: SeanNaren <[email protected]>
* Make sure RETRO integrates well with the core (#6207)
  * fix tests Signed-off-by: Yi Dong <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
  ---------
  Signed-off-by: Yi Dong <[email protected]>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [NLP] Support T5 with Megatron Core (#6222)
  * Support T5 with Megatron Core Signed-off-by: SeanNaren <[email protected]>
  * Remove comment Signed-off-by: SeanNaren <[email protected]>
  * Update prediction step Signed-off-by: SeanNaren <[email protected]>
  * Further changes to fix fine-tuning Signed-off-by: SeanNaren <[email protected]>
  * Bug fixes from runs Signed-off-by: SeanNaren <[email protected]>
  * Revert changes to batch sampler, swap to pretrained sampler Signed-off-by: SeanNaren <[email protected]>
  * Address feedback Signed-off-by: SeanNaren <[email protected]>
  ---------
  Signed-off-by: SeanNaren <[email protected]>
* GPT P-tuning core (max_len pad -> slow) Signed-off-by: Abhinav Khattar <[email protected]>
* add GPT p-tuning w/ global batch based passing Signed-off-by: Abhinav Khattar <[email protected]>
* add T5 p-tuning support Signed-off-by: Abhinav Khattar <[email protected]>
* add megatron core install to Jenkinsfile Signed-off-by: ericharper <[email protected]>
* fix command Signed-off-by: ericharper <[email protected]>
* add guard efault for arg Signed-off-by: ericharper <[email protected]>
* shift bert, retro, adapter + other namespace changes Signed-off-by: Abhinav Khattar <[email protected]>
* build_model merge into one Signed-off-by: Abhinav Khattar <[email protected]>
* Ensure fine-tuning/prompt learning work for T5 (#6385) Signed-off-by: SeanNaren <[email protected]>
* rm extra split impl Signed-off-by: Abhinav Khattar <[email protected]>
* fix for CI Signed-off-by: Abhinav Khattar <[email protected]>
* temp change for tests Signed-off-by: Abhinav Khattar <[email protected]>
* add bs=1 for log Signed-off-by: Abhinav Khattar <[email protected]>
* fix Signed-off-by: Abhinav Khattar <[email protected]>
* iter changes NMT Signed-off-by: Abhinav Khattar <[email protected]>
* NMT partial fix Signed-off-by: Abhinav Khattar <[email protected]>
* move on_train_batch_end to base_model Signed-off-by: Abhinav Khattar <[email protected]>
* rm on_train_batch_end Signed-off-by: Abhinav Khattar <[email protected]>
* temp remove NMT test Signed-off-by: Abhinav Khattar <[email protected]>
* add training_step logic for T5 derived dynamic len models Signed-off-by: Abhinav Khattar <[email protected]>
* add NMT test back Signed-off-by: Abhinav Khattar <[email protected]>
* style fix Signed-off-by: Abhinav Khattar <[email protected]>
* change no_async_tensor_model_parallel_allreduce Signed-off-by: Abhinav Khattar <[email protected]>
* sequence_parallel_enabled -> sequence_parallel Signed-off-by: Abhinav Khattar <[email protected]>
* fix T5 FT batch size Signed-off-by: Abhinav Khattar <[email protected]>
* seq enabled Signed-off-by: Abhinav Khattar <[email protected]>
* T5 sequence length fix Signed-off-by: Abhinav Khattar <[email protected]>
* NMT mp fork to spawn Signed-off-by: Abhinav Khattar <[email protected]>
* make function signatures consistent across models Signed-off-by: Abhinav Khattar <[email protected]>
* make print log Signed-off-by: Abhinav Khattar <[email protected]>
* rm unused import Signed-off-by: Abhinav Khattar <[email protected]>
* update Dockerfile to install core Signed-off-by: Abhinav Khattar <[email protected]>
* keep core path in workspace Signed-off-by: Abhinav Khattar <[email protected]>
---------
Signed-off-by: ericharper <[email protected]>
Signed-off-by: Abhinav Khattar <[email protected]>
Signed-off-by: SeanNaren <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Co-authored-by: ericharper <[email protected]>
Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Yi Dong <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
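A recurring change in the commit list above is the move of parallel_state and tensor_parallel to the megatron.core namespace. As a rough illustration (not code from this PR), the new-style import looks like the sketch below; the initialize_model_parallel call mentioned in the comments is an assumption about typical usage and requires a torch.distributed process group to already be set up.

```python
# Illustrative sketch only (not from this PR): after the namespace change,
# parallel_state and tensor_parallel come from the megatron.core package
# rather than the legacy Apex/Megatron-LM paths. Requires megatron-core.
from megatron.core import parallel_state, tensor_parallel  # noqa: F401

# These helpers only return meaningful values once model parallelism has been
# initialized, e.g. via:
#   parallel_state.initialize_model_parallel(tensor_model_parallel_size=2)
# which itself requires torch.distributed.init_process_group() to have run.
if parallel_state.model_parallel_is_initialized():
    print("TP rank:", parallel_state.get_tensor_model_parallel_rank())
```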
hsiehjackson pushed a commit to hsiehjackson/NeMo that referenced this pull request on Jun 2, 2023:
(Same squashed commit message as above, with PR references rewritten as NVIDIA#6207, NVIDIA#6222, and NVIDIA#6385, plus an additional trailer: Signed-off-by: hsiehjackson <[email protected]>.)
What does this PR do?
Adds Megatron Core support for our T5 model. This currently only works for pre-training; fine-tuning appears to be broken due to the dynamic max lengths in the GLUE/XNLI datasets. The fix going forward, I think, will be to pad every batch to the maximum length so that sequence sizes are always the same, which is a requirement of the iterator object now used in the training_step. I also need to confirm that the weights of the model are the same between main and this branch.
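To make the padding idea concrete, here is a minimal sketch (not part of this PR) of a collate step that pads every sequence to a fixed maximum length, so the iterator consumed by training_step always receives identically shaped tensors. The function name, max_seq_len, and pad_token_id are illustrative assumptions, not NeMo APIs.

```python
import torch

# Hypothetical sketch: pad/truncate variable-length token sequences to one
# fixed length so every micro-batch has an identical shape. The names
# (collate_fixed_len, max_seq_len, pad_token_id) are illustrative only.
def collate_fixed_len(batch, max_seq_len=512, pad_token_id=0):
    out = torch.full((len(batch), max_seq_len), pad_token_id, dtype=torch.long)
    for i, seq in enumerate(batch):
        n = min(len(seq), max_seq_len)
        out[i, :n] = torch.as_tensor(seq[:n], dtype=torch.long)
    return out

# Two different-length sequences become one fixed-shape (2, 512) tensor.
print(collate_fixed_len([[5, 6, 7], [8, 9]]).shape)  # torch.Size([2, 512])
```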
Collection: NLP
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs in various areas.
Additional Information