From b95a1698d5e69bcd2f196a9b9da08f58285aa92c Mon Sep 17 00:00:00 2001 From: Micha Livne Date: Fri, 11 Aug 2023 13:59:45 -0400 Subject: [PATCH] Megatron hidden transformations (#6332) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [TTS] bugfix for missing configs. (#4725) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * docs typo fix Signed-off-by: Oleksii Kuchaiev * Fix pynini install in TTS tutorials (#4729) Signed-off-by: Jocelyn Huang Signed-off-by: Jocelyn Huang * Fix ASR notebooks (#4738) Signed-off-by: smajumdar Signed-off-by: smajumdar * Multilingual VAD model (#4734) * add ngc link Signed-off-by: fayejf * add tuned VAD config on ASR data Signed-off-by: fayejf * yaml note Signed-off-by: fayejf * update vad asr notebook with mVAD Signed-off-by: fayejf * update vad infer config comment Signed-off-by: fayejf * fix Signed-off-by: fayejf * mvad sd config for ch109 Signed-off-by: fayejf * update sd readme Signed-off-by: fayejf * add new mVAD model to doc Signed-off-by: fayejf * style fix Signed-off-by: fayejf * update sd tutorial with mVAD Signed-off-by: fayejf * typo fix Signed-off-by: fayejf Signed-off-by: fayejf * publish pretrained itn t5 model for English (#4748) Signed-off-by: Alexandra Antonova Signed-off-by: Alexandra Antonova Co-authored-by: Alexandra Antonova * Updated docs and doc paths (#4754) * Updated docs and doc paths Signed-off-by: Virginia Adams * Update Multitask_Prompt_and_PTuning.ipynb * Update README.rst * Changed branch name to use single quotes Signed-off-by: Virginia Adams Signed-off-by: Virginia Adams * fix bug relating to ddp strategy in joint intent slot classification tutorial (#4762) * [TTS] updated config with a German IPA phoneme tokenizer (#4756) * [TTS] added a German IPA phoneme tokenizer * [TTS][ASR] enabled customized arguments for trimming the leading and trailing silence. * [TTS] disabled spline interpolation for beta-binomial distribution. Let it generate align prior and save to disks. Use a new phoneme tokenizer. * [TTS] use consistent spline interpolation with fastpitch checkpoint when generating mel-spectrograms for hifigan finetune. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Update r1.11 to new heteronyms list (#4745) * Update configs to new heteronyms list * Remove old heteronyms list, add alt 'merchandise' pron to CMUdict * Update remaining references to old heteronyms list Signed-off-by: Jocelyn Huang Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [TTS] Add multi-speaker German FastPitch and HiFiGAN NGC checkpoints (#4763) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [TTS] Add single male speaker German FastPitch and HiFiGAN NGC checkpoints (#4770) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Update CMUdict with more recent 0.7b entries (#4768) Signed-off-by: Jocelyn Huang Signed-off-by: Jocelyn Huang Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Install pynini in docker container (#4733) Signed-off-by: Vladimir Bataev Signed-off-by: Vladimir Bataev Co-authored-by: Nithin Rao Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Eric Harper * Fix tutorial formatting (#4778) Signed-off-by: Jocelyn Huang * [TTS] deprecated old scripts for ljspeech. (#4780) * deprecated old scripts for ljspeech. 
* removed relevent function calls in TTS docs. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * update branch and typos (#4788) Signed-off-by: ericharper Signed-off-by: ericharper * Adding support for models trained with full context for cache-aware streaming. (#4687) * added support for models trained with full context. Signed-off-by: Vahid * fixed style. Signed-off-by: Vahid * dropped seq_range Signed-off-by: Vahid * fixed indexing in caching methods. Signed-off-by: Vahid * fixed code style. Signed-off-by: Vahid * fixed code style. Signed-off-by: Vahid * updated docs. Signed-off-by: Vahid * addressed comments. Signed-off-by: Vahid * fixed code style. Signed-off-by: Vahid * fixed code style. Signed-off-by: Vahid * fixed code style. Signed-off-by: Vahid * change frame-wise to cache-aware. Signed-off-by: Vahid * change frame-wise to cache-aware. Signed-off-by: Vahid * change frame-wise to cache-aware. Signed-off-by: Vahid * fixed code style. Signed-off-by: Vahid Signed-off-by: Vahid * Update megatron encoder decoder model to support py37 for colab (#4791) * [ASR] Add pretrained ASR models for Croatian (#4682) * [ASR] Add pretrained ASR models for Croatian Signed-off-by: Ante Jukić * Fix style for import Signed-off-by: Ante Jukić Signed-off-by: Ante Jukić Co-authored-by: Ante Jukić Co-authored-by: Nithin Rao Co-authored-by: Eric Harper Co-authored-by: Somshubra Majumdar * added/fixed export for Megatron models (#4712) * added/fixed export for Megatron models Signed-off-by: David Mosallanezhad * fixed style Signed-off-by: David Mosallanezhad * fixed FusedScaleMaskSoftmax in BioMegatron Signed-off-by: David Mosallanezhad * included comments Signed-off-by: David Mosallanezhad Signed-off-by: David Mosallanezhad Co-authored-by: David Mosallanezhad Co-authored-by: Eric Harper * update branch for qa notebook Signed-off-by: ericharper * Fix initializing weights from ptl ckpt with exclude (#4807) Signed-off-by: sam1373 Signed-off-by: sam1373 * Fix index error from addition of voiced_mask and p_voiced (#4811) Signed-off-by: Jocelyn Huang Signed-off-by: Jocelyn Huang * T5 prompt learning fixes (#4771) * RPE, hidden size and config fixes Signed-off-by: MaximumEntropy * Update to reflect new config names Signed-off-by: MaximumEntropy * Sentencepiece fixes Signed-off-by: MaximumEntropy * Style Signed-off-by: MaximumEntropy * Fix finetuning Signed-off-by: MaximumEntropy * Add encoder seq len to gpt Signed-off-by: MaximumEntropy * Style Signed-off-by: MaximumEntropy * Add finetune eval script Signed-off-by: MaximumEntropy * Fix name Signed-off-by: MaximumEntropy * Update Jenkinsfile Signed-off-by: MaximumEntropy * Update config Signed-off-by: MaximumEntropy * Fix CI test Signed-off-by: MaximumEntropy * Update check Signed-off-by: MaximumEntropy * Style Signed-off-by: MaximumEntropy * Backward compat Signed-off-by: MaximumEntropy * Update CI test Signed-off-by: MaximumEntropy * Split rank for Enc-Dec models Signed-off-by: MaximumEntropy * Address comments Signed-off-by: MaximumEntropy * Style Signed-off-by: MaximumEntropy Signed-off-by: MaximumEntropy Co-authored-by: Virginia Adams <78445382+vadam5@users.noreply.github.com> * G2P docs (#4841) * g2p docs added Signed-off-by: ekmb * fix references Signed-off-by: ekmb * address review feedback Signed-off-by: ekmb Signed-off-by: ekmb * Fix providing glue in seq2seq eval (#4843) * Fix providing glue in seq2seq eval Signed-off-by: MaximumEntropy * Fix Signed-off-by: MaximumEntropy * Style Signed-off-by: MaximumEntropy Signed-off-by: 
MaximumEntropy * Updated inference code and squad scripts (#4835) * Updated inference code and squad scripts Signed-off-by: Virginia Adams * Reverted GPT & T5 inference files back to use NLPDDPlugin Signed-off-by: Virginia Adams * Overwrite frozen LM to use fused adam Signed-off-by: Virginia Adams * Added padded vocab size Signed-off-by: Virginia Adams * Fixed val check interval value Signed-off-by: Virginia Adams * Python format fix Signed-off-by: Virginia Adams * Make t5 prompt learning preds write to file Signed-off-by: Virginia Adams * Added back dp=1 check Signed-off-by: Virginia Adams Signed-off-by: Virginia Adams Co-authored-by: Sandeep Subramanian * Update README.rst * Fix uppercasing mismatch for IPA heteronyms (#4860) Signed-off-by: Jocelyn Huang Signed-off-by: Jocelyn Huang * Set the number of workers to 0 for validation and test sets in all enc-dec models (#4790) * Set workers to 0 for validation and test Signed-off-by: MaximumEntropy * Revert pin memory Signed-off-by: MaximumEntropy * Style Signed-off-by: MaximumEntropy Signed-off-by: MaximumEntropy Co-authored-by: Sean Naren * Fix mha (#4866) * fix bug in mha forward function related to cache update return type Signed-off-by: Yang Zhang * fix lgtm Signed-off-by: Yang Zhang Signed-off-by: Yang Zhang Co-authored-by: Sean Naren * ipa bug fix (#4871) Signed-off-by: ekmb Signed-off-by: ekmb * Fix Megatron NMT consumed samples and ckpt_to_nemo split rank (#4884) * Fix nmt and ckpt_to_nemo Signed-off-by: MaximumEntropy * Style Signed-off-by: MaximumEntropy Signed-off-by: MaximumEntropy * added utf8 encoding (#4892) Signed-off-by: Virginia Adams Signed-off-by: Virginia Adams * 1. Applying the same patch to r1.11.0 (#4894) Signed-off-by: Micha Livne Signed-off-by: Micha Livne * Update tutorials.rst (#4897) * update readme with apex commit Signed-off-by: ericharper * Add support for Apex distributed Adam optimizer with GPT-3 (#4487) * Add support for Apex distributed Adam optimizer with GPT-3 Signed-off-by: Tim Moon * Fix bug in grad clipping with dist Adam Grad norm was computed over all params, not respecting model parallelism. Signed-off-by: Tim Moon * Fix bug with DDP initialization Signed-off-by: Tim Moon * Make distopt dependent on megatron_amp_o2 Signed-off-by: Tim Moon * Fix code formatting Signed-off-by: Tim Moon * Handle dist Adam in optimizer unit tests Signed-off-by: Tim Moon Signed-off-by: Tim Moon Co-authored-by: Eric Harper * update readme Signed-off-by: ericharper * update readme Signed-off-by: ericharper * latent model support * 1. Debugging. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. * update branch Signed-off-by: ericharper * fix replace_bos_with_pad not found (#6443) Signed-off-by: Abhinav Khattar * Support Swiglu in TP PP Conversion (#6437) * Support Swiglu in TP PP Conversion Signed-off-by: smajumdar * Guard activation Signed-off-by: smajumdar * Guard activation Signed-off-by: smajumdar --------- Signed-off-by: smajumdar * BERT pre-training mp fork to spawn (#6442) * change bert fork to spawn Signed-off-by: Abhinav Khattar * num_workers=0 fix Signed-off-by: Abhinav Khattar --------- Signed-off-by: Abhinav Khattar * Meagtron encoder decoder fix for empty validation outputs (#6459) * 1. Meagtron encoder decoder fix for empty validation outputs. Signed-off-by: Micha Livne * 1. Debugging. --------- Signed-off-by: Micha Livne Co-authored-by: Micha Livne * Added/updated new Conformer configs (#6426) * updated conf files. 
Signed-off-by: Vahid * added confs. Signed-off-by: Vahid * moved longconformer confs. Signed-off-by: Vahid * updated readme. Signed-off-by: Vahid * updated readme. Signed-off-by: Vahid * updated batch sizes and added fastconformer ctc streaming configs. Signed-off-by: Vahid * updated batch sizes. Signed-off-by: Vahid * added hybrid support. Signed-off-by: Vahid * added hybrid support. Signed-off-by: Vahid --------- Signed-off-by: Vahid * reduce workers on NMT CI (#6472) Signed-off-by: Abhinav Khattar * move to nvidia megatron repo (#6465) Signed-off-by: Abhinav Khattar * Megatron KERPLE positional embeddings (#6478) * [TTS] FastPitch adapter fine-tune and conditional layer normalization (#6416) [TTS] FastPitch adapter fine-tune and conditional layer normalization (#6416) --------- Signed-off-by: hsiehjackson Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [TTS] whitelist broken path fix. (#6412) * [TTS] whitelist broken path fix. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [TTS] FastPitch speaker encoder (#6417) * Add initial codes Signed-off-by: hsiehjackson * Remove wemb Signed-off-by: hsiehjackson * Fix import Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Restore aligner loss Signed-off-by: hsiehjackson * Add ConditionalInput Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix error and support pre-trained config Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Follow comments Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rename config Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Change copyright and random weight test Signed-off-by: hsiehjackson * Add initial codes Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: hsiehjackson * Fix import error Signed-off-by: hsiehjackson * Add initial codes Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: hsiehjackson * Fix dataset error Signed-off-by: hsiehjackson * Remove reference speaker embedding Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: hsiehjackson * Remove SV encoder Signed-off-by: hsiehjackson * Follow comments Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: hsiehjackson * Fix length type Signed-off-by: hsiehjackson * Fix append Signed-off-by: hsiehjackson * Move error msg Signed-off-by: hsiehjackson * Add look-up into speaker encoder Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: 
hsiehjackson * Add valueerror msg Signed-off-by: hsiehjackson * Move lookup Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: hsiehjackson * Remove unused Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: hsiehjackson * Fix error Signed-off-by: hsiehjackson * Rebase and Fix error Signed-off-by: hsiehjackson * Fix spk encoder Signed-off-by: hsiehjackson * Rename n_speakers Signed-off-by: hsiehjackson * Follow comments Signed-off-by: hsiehjackson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix n_speakers None error Signed-off-by: hsiehjackson --------- Signed-off-by: hsiehjackson Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Sharded manifests for tarred datasets (#6395) * testing sharded manifests Signed-off-by: Dima Rekesh * compatibility Signed-off-by: Dima Rekesh * proper fixes Signed-off-by: Dima Rekesh * adding flag tot convert_to_tarred_audio_dataset Signed-off-by: Dima Rekesh * shard_manifests conf param Signed-off-by: Dima Rekesh * propagating the shard_manifests param Signed-off-by: Dima Rekesh * propagating the shard_manifests param Signed-off-by: Dima Rekesh * distributed checks Signed-off-by: Dima Rekesh * typo Signed-off-by: Dima Rekesh * typo Signed-off-by: Dima Rekesh * fixes Signed-off-by: Dima Rekesh * fixes Signed-off-by: Dima Rekesh * fixes Signed-off-by: Dima Rekesh * fixes Signed-off-by: Dima Rekesh * fixes Signed-off-by: Dima Rekesh * fixes Signed-off-by: Dima Rekesh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes based on PR comments and tests Signed-off-by: Dima Rekesh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes to convert_to_tarred_audio_dataset.py Signed-off-by: Dima Rekesh * reversing manifest shards flag Signed-off-by: Dima Rekesh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tests Signed-off-by: Dima Rekesh * excluding manifests from webdataset url expansion Signed-off-by: Dima Rekesh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * expand manifest paths before attempting to cache from datastore Signed-off-by: Dima Rekesh * explicit use of UTF-8 for manifest i/o Signed-off-by: Dima Rekesh --------- Signed-off-by: Dima Rekesh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Update wfst_text_normalization.rst (#6374) Add Hungarian (incoming in NeMo-text-processing) Signed-off-by: Jim O’Regan * Support Swiglu in TP PP Conversion (#6437) (#6451) * Support Swiglu in TP PP Conversion * Guard activation * Guard activation --------- Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar * Update NeMo_TTS_Primer.ipynb (#6436) * Update NeMo_TTS_Primer.ipynb Changed a mistake in line 782. Instead of frequency band (ie. pitch) we should write frequency bin. Note that frequency bins in FFT are not related to pitch. Signed-off-by: Mostafa Ghorbandoost * Update NeMo_TTS_Primer.ipynb Corrected the description of spectrogram and mel spectrogram calculations in lines 782 & 783 and added a fourth point to the description and added a reference for more mathematical details at the end of this point. 
Signed-off-by: Mostafa Ghorbandoost --------- Signed-off-by: Mostafa Ghorbandoost * add rampup batch size support for Megatron GPT (#6424) * added rampup batch size support Signed-off-by: Dmytro Pykhtar * added tests for rampup batch size Signed-off-by: Dmytro Pykhtar * fixed the typos Signed-off-by: Dmytro Pykhtar * added assertions Signed-off-by: Dmytro Pykhtar * changed assertion rules Signed-off-by: Dmytro Pykhtar * deleted unused imports Signed-off-by: Dmytro Pykhtar * changed tests for rampup batch size Signed-off-by: Dmytro Pykhtar * updated rampup batch size tests Signed-off-by: Dmytro Pykhtar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed styling Signed-off-by: Dmytro Pykhtar * rampup batch size tests changes Signed-off-by: Dmytro Pykhtar --------- Signed-off-by: Dmytro Pykhtar Signed-off-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Co-authored-by: Dmytro Pykhtar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper * Meagtron encoder decoder fix for empty validation outputs (#6459) (#6461) * 1. Meagtron encoder decoder fix for empty validation outputs. * 1. Debugging. --------- Signed-off-by: Micha Livne Co-authored-by: Micha Livne Co-authored-by: Micha Livne * Code-Switching dataset creation - upgrading to aggregate tokenizer manifest format (#6448) * added functionality to create agg tokenizer compatible manifest for CS, flag to use this mode by default Signed-off-by: Kunal Dhawan * updated README with the new agg_tokenizer_manifest flag Signed-off-by: Kunal Dhawan * fixed typo in scripts/speech_recognition/code_switching/README.md Signed-off-by: Kunal Dhawan * changed agg_tokenizer_manifest to is_lid_manifest Signed-off-by: Kunal Dhawan --------- Signed-off-by: Kunal Dhawan Co-authored-by: Dima Rekesh * Added/updated new Conformer configs (#6426) (#6467) * Update script for ngram rnnt and hat beam search decoding (#6370) * add rnnt ngram beamsearch script Signed-off-by: andrusenkoau * add return encoding embedding option Signed-off-by: andrusenkoau * update script Signed-off-by: andrusenkoau * add rnnt and hat ngram decoding script Signed-off-by: andrusenkoau * add some parameters Signed-off-by: andrusenkoau * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add return_encoder_embeddings parameter to RNNTDecodingConfig Signed-off-by: andrusenkoau * replace return_encoder_embeddings parameter Signed-off-by: andrusenkoau * generalization of scipt behavior Signed-off-by: andrusenkoau * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove return_encoder_embeddings parameter Signed-off-by: andrusenkoau * remove return_encoder_embeddings parameter Signed-off-by: andrusenkoau * add manual encoder_embeddings calculation Signed-off-by: andrusenkoau * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix beam_width value to 8 Signed-off-by: Andrei Andrusenko <52885736+andrusenkoau@users.noreply.github.com> * fix rescoring description Signed-off-by: Andrei Andrusenko <52885736+andrusenkoau@users.noreply.github.com> --------- Signed-off-by: andrusenkoau Signed-off-by: Andrei Andrusenko <52885736+andrusenkoau@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Somshubra Majumdar * BERT pre-training mp 
fork to spawn (#6442) (#6454) * change bert fork to spawn * num_workers=0 fix --------- Signed-off-by: Abhinav Khattar Co-authored-by: Abhinav Khattar * fix replace_bos_with_pad not found (#6443) (#6450) Signed-off-by: Abhinav Khattar Co-authored-by: Abhinav Khattar * reduce workers on NMT CI (#6472) (#6474) Signed-off-by: Abhinav Khattar Co-authored-by: Abhinav Khattar * 1. Added KERPLE positional embeddings to encoder-decoder. Signed-off-by: Micha Livne * 1. Added a missing file. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Fixing commits. Signed-off-by: Micha Livne * 1. Debugging. * 1. Debugging. * 1. Debugging. * 1. Debugging. --------- Signed-off-by: hsiehjackson Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Dima Rekesh Signed-off-by: Jim O’Regan Signed-off-by: smajumdar Signed-off-by: Mostafa Ghorbandoost Signed-off-by: Dmytro Pykhtar Signed-off-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Signed-off-by: Micha Livne Signed-off-by: Kunal Dhawan Signed-off-by: andrusenkoau Signed-off-by: Andrei Andrusenko <52885736+andrusenkoau@users.noreply.github.com> Signed-off-by: Abhinav Khattar Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Dima Rekesh Co-authored-by: Jim O’Regan Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Somshubra Majumdar Co-authored-by: Mostafa Ghorbandoost Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Co-authored-by: Dmytro Pykhtar Co-authored-by: Eric Harper Co-authored-by: Micha Livne Co-authored-by: Kunal Dhawan Co-authored-by: Andrei Andrusenko <52885736+andrusenkoau@users.noreply.github.com> Co-authored-by: Abhinav Khattar * 1. Added external index sample. (#6462) Signed-off-by: Micha Livne * Fix cache aware hybrid bugs (#6466) * Update README to add core installation (#6488) * update README for megatron-core Signed-off-by: Abhinav Khattar * fix Signed-off-by: Abhinav Khattar --------- Signed-off-by: Abhinav Khattar * Fix typos (#6494) Signed-off-by: smajumdar * fix broken links r1.18.0 (#6501) * fix broken links Signed-off-by: Evelina * fix broken links Signed-off-by: Evelina --------- Signed-off-by: Evelina * 1. Fixed gaussian hidden transform. Signed-off-by: Micha Livne * 1. Finished updating hidden loss for MIM. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix custom forward_torch_softmax (#6512) Signed-off-by: Abhinav Khattar * [BugFix] Force _get_batch_preds() to keep logits in decoder timestamp… (#6500) * [BugFix] Force _get_batch_preds() to keep logits in decoder timestamps generator r1.18.0 Signed-off-by: Taejin Park * ignore keep_logits in FrameBatchASRLogits Signed-off-by: Taejin Park --------- Signed-off-by: Taejin Park * [TTS] fixed broken path. (#6514) Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * 1. Added a hiddens module. 
Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix typos (#6523) (#6539) * Fix typos Signed-off-by: smajumdar * Fix typos Signed-off-by: smajumdar --------- Signed-off-by: smajumdar (cherry picked from commit 5468077f5127be1a4c88065de2544f4268b9a6e4) * added back the fast emit section to the configs. (#6540) * added back the fast emit section to the configs. Signed-off-by: Vahid * added back the fast emit section to the configs. Signed-off-by: Vahid --------- Signed-off-by: Vahid * Fix fp16 (#6543) Signed-off-by: MaximumEntropy * fix (#6529) Signed-off-by: Abhinav Khattar * pass .scale instead of scaler object to core (#6545) Signed-off-by: Abhinav Khattar Co-authored-by: Eric Harper * Change Megatron Enc Dec model to use persistent_workers (#6548) * persistent workers Signed-off-by: Abhinav Khattar * fix Signed-off-by: Abhinav Khattar --------- Signed-off-by: Abhinav Khattar Co-authored-by: Eric Harper * Add FastConformer Hybrid ASR models for EN, ES, IT, DE, PL, HR, UA, BY (#6549) * Added fastconfomer hybrid asr models for en, es, it, de, pl, hr, ua, by Signed-off-by: KunalDhawan * updated ASR docs with the fastconformer hybrid checkpoints Signed-off-by: KunalDhawan * added the fastconformer RNNT and CTC models Signed-off-by: KunalDhawan --------- Signed-off-by: KunalDhawan * Add scores for FastConformer models (#6557) Signed-off-by: smajumdar * Patch transcribe and support offline transcribe for hybrid model (#6550) Signed-off-by: fayejf * Not doing CastToFloat by default (#6524) * Not doing CastToFloat by default Signed-off-by: Boris Fomitchev * Added docustring Signed-off-by: Boris Fomitchev * Dummy commit Signed-off-by: Boris Fomitchev --------- Signed-off-by: Boris Fomitchev * temp rtd fix (#6568) Signed-off-by: Abhinav Khattar * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update manifest.py for speedup (#6565) * Update manifest.py Re-order the checks for faster processing audio filepaths that are already absolute paths Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> * Update manifest.py Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> --------- Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> Co-authored-by: Vahid Noroozi * Turn autocast off when precision is fp32 (#6554) * Turn autocast off when precision is fp32 Signed-off-by: Abhinav Khattar * address review Signed-off-by: Abhinav Khattar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Abhinav Khattar * merge Signed-off-by: Abhinav Khattar --------- Signed-off-by: Abhinav Khattar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper * More streaming conformer export fixes (#6567) Signed-off-by: Greg Clark Co-authored-by: Vahid Noroozi * Fix batch size reconf for T5 FT for multi-validation (#6582) Signed-off-by: Abhinav Khattar * 1. 
Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Updated Megatron LM encoder/decoder to use cfg for hiddens. Signed-off-by: Micha Livne * 1. Added support to register externalhidden loss / transforms. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Make tensor split contiguous (#6580) Signed-off-by: Abhinav Khattar * Patches from main to r1.18.0 for Virtual Parallel (#6592) * Add interleaved pp support (#6498) * Add support for Virtual Pipeline Parallel conversion Signed-off-by: smajumdar * Add support for Virtual Pipeline Parallel conversion Signed-off-by: smajumdar * Switch to megatron core Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: smajumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 892987169ef277f328e15b71a5a0c9bd961c8ee7) * Add patches for Virtual Parallel conversion (#6589) * Add patches for Virtual Parllel conversion Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: smajumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit 1d813a372ab51688e3af6395d905a4c0366ffd23) * Documentation for ASR-TTS models (#6594) * Add docs about hybrid ASR-TTS models Signed-off-by: Vladimir Bataev * Add docs about text-only datasets Signed-off-by: Vladimir Bataev * Add docs about ASR-TTS checkpoints Signed-off-by: Vladimir Bataev * Add docs about ASR-TTS configs and training Signed-off-by: Vladimir Bataev * Clean up Signed-off-by: Vladimir Bataev * ASR-TTS docs: add to api, fix imports Signed-off-by: Vladimir Bataev * Clean up Signed-off-by: Vladimir Bataev * Wrap optional import Signed-off-by: Vladimir Bataev * Revert general ASR import Signed-off-by: Vladimir Bataev --------- Signed-off-by: Vladimir Bataev * Update SDP docs (#6485) * add info about SDP e.g. processor classes in docs Signed-off-by: Elena Rastorgueva * add link to SDP docs in README Signed-off-by: Elena Rastorgueva * address code review comments and add SDP overview diagram Signed-off-by: Elena Rastorgueva * Fix spelling typo Signed-off-by: Elena Rastorgueva --------- Signed-off-by: Elena Rastorgueva * Create dummy iters to satisy len checks (#6600) Signed-off-by: Abhinav Khattar * 1. Debugging. 
Signed-off-by: Micha Livne * Restore GPT support for interleaved pipeline parallelism (#6528) * Restore logic for data-parallel communication with pipeline parallelism in GPT Signed-off-by: Tim Moon * Support dynamic attention masks in GPT Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Debug typos Signed-off-by: Tim Moon * Debug data iterator caching with interleaved pipeline parallelism Each model chunk accesses the data iterator multiple times, so we need to cache multiple samples. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update Megatron-LM commit Signed-off-by: Tim Moon * Distinguish between list of data iterators and data iterator that is a list Signed-off-by: Tim Moon * Create dummy iters to satisy len checks Signed-off-by: Abhinav Khattar * Kludge while waiting for Megatron-LM update Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * set transformers offline to avoid rate limiting Signed-off-by: ericharper --------- Signed-off-by: Tim Moon Signed-off-by: Eric Harper Signed-off-by: Abhinav Khattar Signed-off-by: ericharper Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper Co-authored-by: Abhinav Khattar * Patch transcribe_util for steaming mode and add wer calculation back to inference scripts (#6601) * fix write Signed-off-by: fayejf * decoding ctc Signed-off-by: fayejf * temp set rnnt decoding return_best_hypothesis to true Signed-off-by: fayejf * add wer cal back to transcribe_speech as requested Signed-off-by: fayejf * add wer cal back to speech_to_text_buffered_infer_rnnt as requested Signed-off-by: fayejf * add wer cal back to speech_to_text_buffered_infer_ctc as requested Signed-off-by: fayejf * style fix Signed-off-by: fayejf * reflect change in asr_evaluator Signed-off-by: fayejf * reflect som and vahid comment Signed-off-by: fayejf * remove return_best_hy=true in transcribe_speech Signed-off-by: fayejf * no text skip Signed-off-by: fayejf --------- Signed-off-by: fayejf * 1. Added example conf YAML. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Added support in tensor_parallel. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add hat image to docs (#6619) Signed-off-by: andrusenkoau * update core commit hash in readme (#6622) Signed-off-by: Abhinav Khattar * Patch decoding for PC models (#6630) * Patch decoding logic for PC models Signed-off-by: smajumdar * Patch decoding logic for PC models Signed-off-by: smajumdar --------- Signed-off-by: smajumdar * Fix wer.py where 'errors' variable was not set (#6633) Fix wer.py where 'errors' variable was not set when both reference and hypothesis are empty strings Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> * fix att_context_size bug for older models. 
(#6635) Signed-off-by: Vahid * Add megatron_core to requirements (#6639) * add megatron_core to requirements Signed-off-by: ericharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: ericharper Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Remove from jenkins (#6641) * add megatron_core to requirements Signed-off-by: ericharper * remove from jenkins Signed-off-by: ericharper --------- Signed-off-by: ericharper * remove dup (#6643) Signed-off-by: ericharper * 1. Fixed config to use names, and added better error messages. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Added support to pass extra data to hiddens for loss computation. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Working on passing extra data to hiddnes. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Fixed support in loading .nemo without hiddnes module. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. 
Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Improved and fixed logging of validation and testing. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Fixed training logging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Fixed logging of hidden loss. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Fixed logging names. 2. Added logging to hiddens and tokens loss. Signed-off-by: Micha Livne * 1. Fixed conflicts. Signed-off-by: Micha Livne * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. Signed-off-by: Micha Livne * 1. Debugging. 
Signed-off-by: Micha Livne --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Oleksii Kuchaiev Signed-off-by: Jocelyn Huang Signed-off-by: smajumdar Signed-off-by: fayejf Signed-off-by: Alexandra Antonova Signed-off-by: Virginia Adams Signed-off-by: Vladimir Bataev Signed-off-by: ericharper Signed-off-by: Vahid Signed-off-by: Ante Jukić Signed-off-by: David Mosallanezhad Signed-off-by: sam1373 Signed-off-by: MaximumEntropy Signed-off-by: ekmb Signed-off-by: Yang Zhang Signed-off-by: Micha Livne Signed-off-by: Tim Moon Signed-off-by: Abhinav Khattar Signed-off-by: smajumdar Signed-off-by: Micha Livne Signed-off-by: hsiehjackson Signed-off-by: Dima Rekesh Signed-off-by: Jim O’Regan Signed-off-by: Mostafa Ghorbandoost Signed-off-by: Dmytro Pykhtar Signed-off-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Signed-off-by: Kunal Dhawan Signed-off-by: andrusenkoau Signed-off-by: Andrei Andrusenko <52885736+andrusenkoau@users.noreply.github.com> Signed-off-by: Evelina Signed-off-by: Taejin Park Signed-off-by: KunalDhawan Signed-off-by: Boris Fomitchev Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> Signed-off-by: Greg Clark Signed-off-by: Elena Rastorgueva Signed-off-by: Eric Harper Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Oleksii Kuchaiev Co-authored-by: Jocelyn Co-authored-by: Somshubra Majumdar Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com> Co-authored-by: bene-ges <61418381+bene-ges@users.noreply.github.com> Co-authored-by: Alexandra Antonova Co-authored-by: Virginia Adams <78445382+vadam5@users.noreply.github.com> Co-authored-by: Zhilin Wang Co-authored-by: Vladimir Bataev Co-authored-by: Nithin Rao Co-authored-by: Eric Harper Co-authored-by: Vahid Noroozi Co-authored-by: anteju <108555623+anteju@users.noreply.github.com> Co-authored-by: Ante Jukić Co-authored-by: David Co-authored-by: David Mosallanezhad Co-authored-by: Samuel Kriman Co-authored-by: Sandeep Subramanian Co-authored-by: Evelina <10428420+ekmb@users.noreply.github.com> Co-authored-by: Sean Naren Co-authored-by: Yang Zhang Co-authored-by: Sean Naren Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Neha Tadimeti Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Abhinav Khattar Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Co-authored-by: Dima Rekesh Co-authored-by: Jim O’Regan Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Mostafa Ghorbandoost Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Co-authored-by: Dmytro Pykhtar Co-authored-by: Kunal Dhawan Co-authored-by: Andrei Andrusenko <52885736+andrusenkoau@users.noreply.github.com> Co-authored-by: Taejin Park Co-authored-by: Boris Fomitchev Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> Co-authored-by: Greg Clark Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> --- .../conf/megatron_hiddens_base_config.yaml | 43 +++ .../language_modeling/megatron_bert_model.py | 6 +- .../megatron_finetune_model.py | 42 +-- .../megatron_lm_encoder_decoder_model.py | 219 +++++++++---- .../megatron_t5_prompt_learning_model.py | 5 +- .../machine_translation/megatron_nmt_model.py | 37 +-- .../megatron/megatron_encoder_decoder.py | 51 ++- 
.../megatron/token_level_encoder_decoder.py | 43 ++- .../megatron/transformations/__init__.py | 16 + .../transformations/megatron_hidden_loss.py | 189 +++++++++++ .../megatron_hidden_transform.py | 170 ++++++++++ .../transformations/megatron_hiddens.py | 310 ++++++++++++++++++ .../nlp/modules/common/megatron/utils.py | 8 +- 13 files changed, 993 insertions(+), 146 deletions(-) create mode 100644 examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml create mode 100644 nemo/collections/nlp/modules/common/megatron/transformations/__init__.py create mode 100644 nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_loss.py create mode 100644 nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_transform.py create mode 100644 nemo/collections/nlp/modules/common/megatron/transformations/megatron_hiddens.py diff --git a/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml b/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml new file mode 100644 index 000000000000..d63255d50ed3 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml @@ -0,0 +1,43 @@ +# this file main purpose is documentation, and it should not be used directly +enc_output_name: z # name of key in hidden transforms output to pass to decoder (e.g., z for VAE/MIM) +tokens_loss_weight: 1.0 # weight of tokens loss (if not specified defaults to 1.0) +# the lists below are useful for adding multiple transforms and losses according to order +# if order is not important, you can use a single dictionary in the list with multiple keys +transform: # a list of dictionaries of transforms (or a joint dictionary) to apply to hiddens (list enforces order) + # - : # name of transform + # cls_name: # class path name + # : # transform parameters + # ... + - q_z_given_x: # Gaussian posterior with reparameterization + cls_name: cond_gaussian # class path name + hidden_size: 512 # hidden size of the encoder + min_logvar: -6.0 # minimum log variance + - logP_cls: + cls_name: guided_cls + input_name: hiddens + attr_name: logP + - QED_cls: + cls_name: guided_cls + input_name: hiddens + attr_name: QED +loss: # a list of dictionaries of loss terms (or a joint dictionary) to add to reconstruction loss (list enforces order) + # - : # name of loss + # cls_name: # class path name + # : # loss parameters + # ... 
+ # below is an example where the order of losses does not matter, so a single item in the list is enough + - mim: # A-MIM example + cls_name: a_mim + loss_weight: 1.0 # weight of the MIM latent loss + vae: # VAE example + cls_name: vae + min_kl_value: null # minimum KL value if a float is provided + loss_weight: 1e-2 # weight of KL term in loss + logP_cls: + cls_name: guided_cls_loss + input_name: logP + loss_weight: 1.0 + QED_cls: + cls_name: guided_cls_loss + input_name: QED + loss_weight: 1.0 diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index ddd46e681f94..82a3cbf36b64 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -249,7 +249,7 @@ def loss_func(output_tensor): lm_loss = loss_dict['lm loss'] loss = lm_loss reduced_loss = average_losses_across_data_parallel_group([loss, lm_loss]) - return loss, {'avg': reduced_loss} + return loss, {'loss': reduced_loss} return output_tensor, loss_func @@ -334,7 +334,7 @@ def training_step(self, dataloader_iter, batch_idx): ) if losses_reduced_per_micro_batch: - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensors_list = [loss_reduced['loss'] for loss_reduced in losses_reduced_per_micro_batch] loss_tensor = torch.vstack(loss_tensors_list) loss_mean = loss_tensor.mean(axis=0) else: @@ -447,7 +447,7 @@ def validation_step(self, dataloader_iter, batch_idx): ) if losses_reduced_per_micro_batch: - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensors_list = [loss_reduced['loss'] for loss_reduced in losses_reduced_per_micro_batch] loss_tensor = torch.vstack(loss_tensors_list) loss_mean = loss_tensor.mean(axis=0) else: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py index 9fce0d52c4a1..d854505d1f74 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py @@ -276,47 +276,27 @@ def _reconfigure_and_process_inference_batch(self, batch, ds_config): def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): """ Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ - # Get seq length of batch batch = next(dataloader_iter) if isinstance(batch, dict): # convert to list if not already converted.
batch = self._process_batch(batch) - _, seq_length = batch[0].shape - _, dec_seq_length = batch[1].shape - tensor_shape = [seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] - data_iter = get_iterator_k_split(batch, get_num_microbatches()) + # Get seq length of batch + encoder_seq_length = batch[0].size(1) + decoder_seq_length = batch[1].size(1) - fwd_bwd_function = get_forward_backward_func() + tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] + data_iter = get_iterator_k_split(batch, get_num_microbatches()) - losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(), + return self._execute_fwd_bwd_function( data_iterator=data_iter, - model=[self.enc_dec_model], - num_microbatches=get_num_microbatches(), forward_only=forward_only, tensor_shape=tensor_shape, - decoder_seq_length=dec_seq_length, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None, - sequence_parallel=self.cfg.get('sequence_parallel', False), - enable_autocast=self.enable_autocast, + decoder_seq_length=decoder_seq_length, ) - # only the last stages of the pipeline return losses - if losses_reduced_per_micro_batch: - # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] - loss_tensor = torch.concat(loss_tensors_list) - loss_mean = loss_tensor.mean() - else: - # we're not on the last pipeline stage so no losses - loss_mean = torch.tensor(0.0).cuda() - - return loss_mean - def inference_step(self, dataloader_iter, batch_idx: int, mode: str, dataloader_idx=0): # Add try except since dataloader_iter in PTL 2.0 doesnt catch the end of the iterator try: @@ -366,12 +346,16 @@ def inference_step(self, dataloader_iter, batch_idx: int, mode: str, dataloader_ _ = metric(pred, label) outputs = { - 'loss': loss, 'preds': preds_text, 'labels': labels_text, 'categories': categories, 'inputs': input_text, } + + if isinstance(loss, dict): + outputs.update(loss) + else: + outputs['loss'] = loss if mode == 'validation': if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: self.validation_step_outputs[dataloader_idx].append(outputs) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index f8f8fe808612..ff4da0f624ed 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ -240,7 +240,6 @@ def _populate_encoder_decoder_configs_for_backward_compatibility(self, cfg): ) # For models before separate encoder/decoder configs, tokens_head_bias was always True. def model_provider_func(self, pre_process, post_process, add_encoder, add_decoder): - # TODO: create get_encoder_decoder_model()here for different losses (e..g, nll, vae, mim) if not hasattr(self.cfg, 'encoder') or not hasattr(self.cfg, 'decoder'): logging.warning( 'Could not find encoder or decoder in config. This is probably because of restoring an old checkpoint. Copying shared model configs to encoder and decoder configs.' 
@@ -282,6 +281,7 @@ def model_provider_func(self, pre_process, post_process, add_encoder, add_decode share_token_embeddings=self.cfg.get('share_token_embeddings', True), share_decoder_tokens_head_embeddings=self.cfg.get('share_decoder_tokens_head_embeddings', True), tokens_head_bias=self.cfg.get('tokens_head_bias', True), + hiddens_cfg=self.cfg.get('hiddens', None), ) return model @@ -313,42 +313,54 @@ def forward( return output_tensor - def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + def _execute_fwd_bwd_function(self, data_iterator, forward_only, tensor_shape, decoder_seq_length): """ - Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + An auxiliary function that executes the fwd/bwd function and parses the returned values. """ - # Get seq length of batch - tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size] - fwd_bwd_function = get_forward_backward_func() losses_reduced_per_micro_batch = fwd_bwd_function( forward_step_func=self.get_forward_output_and_loss_func(), - data_iterator=dataloader_iter, + data_iterator=data_iterator, model=[self.enc_dec_model], num_microbatches=get_num_microbatches(), forward_only=forward_only, tensor_shape=tensor_shape, - decoder_seq_length=self.max_decoder_seq_length, + decoder_seq_length=decoder_seq_length, dtype=self.autocast_dtype, grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None, + sequence_parallel=self.cfg.get('sequence_parallel', False), enable_autocast=self.enable_autocast, ) # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: - # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] - loss_tensor = torch.concat(loss_tensors_list) - loss_mean = loss_tensor.mean() + mean_loss_dict = {} + for k in losses_reduced_per_micro_batch[0].keys(): + # average loss across micro batches + mean_loss_dict[k] = torch.stack( + [loss_reduced[k] for loss_reduced in losses_reduced_per_micro_batch] + ).mean() else: - if forward_only: - loss_mean = [] - else: - loss_mean = torch.tensor(0.0).cuda() + loss_mean = torch.tensor(0.0).cuda() + mean_loss_dict = {"loss": loss_mean} - return loss_mean + return mean_loss_dict + + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + """ + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.
+ """ + # Get seq length of batch + tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size] + + return self._execute_fwd_bwd_function( + data_iterator=dataloader_iter, + forward_only=forward_only, + tensor_shape=tensor_shape, + decoder_seq_length=self.max_decoder_seq_length, + ) def training_step(self, dataloader_iter, batch_idx): """ @@ -362,7 +374,7 @@ def training_step(self, dataloader_iter, batch_idx): # we zero grads here because we also call backward in the megatron fwd/bwd functions self._optimizer.zero_grad() - loss_mean = self.fwd_bwd_step(dataloader_iter, batch_idx, False) + loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, False) if self.with_distributed_adam: # synchronize asynchronous grad reductions @@ -386,14 +398,16 @@ def training_step(self, dataloader_iter, batch_idx): ## logging # we can only log on one rank if it is rank zero so we broadcast from last rank # we can avoid this broadcast by updating the PTL log function to accept specific ranks - torch.distributed.broadcast(loss_mean, get_last_rank()) + for k, v in loss_dict.items(): + torch.distributed.broadcast(v, get_last_rank()) + n = f'reduced_train_{k}' + self.log(n, v, prog_bar=n.endswith("_loss"), rank_zero_only=True, batch_size=1) if self.cfg.precision == 16: loss_scale = self.trainer.precision_plugin.scaler._scale if loss_scale is not None: self.log('loss_scale', loss_scale, batch_size=1) - self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) lr = self._optimizer.param_groups[0]['lr'] self.log('lr', lr, rank_zero_only=True, batch_size=1) self.log( @@ -407,7 +421,7 @@ def training_step(self, dataloader_iter, batch_idx): rank_zero_only=True, batch_size=1, ) - return loss_mean + return loss_dict @property def max_decoder_seq_length(self) -> int: @@ -556,16 +570,26 @@ def _process_batch(self, global_batch: Dict[str, torch.Tensor]) -> List[torch.Te global_batch["labels"], global_batch["enc_mask"], global_batch["dec_mask"], + global_batch.get('data', None), ] def get_forward_output_and_loss_func(self): def fwd_output_and_loss_func(dataloader_iter, model): batch = next(dataloader_iter) + # convert to list if not already converted. if isinstance(batch, dict): # convert to list if not already converted. 
batch = self._process_batch(batch) - batch = [x.cuda(non_blocking=True) for x in batch] - encoder_input_ids, decoder_input_ids, loss_mask, lm_labels, encoder_attn_mask, decoder_attn_mask = batch + batch = [x.cuda(non_blocking=True) if torch.is_tensor(x) else x for x in batch] + ( + encoder_input_ids, + decoder_input_ids, + loss_mask, + lm_labels, + encoder_attn_mask, + decoder_attn_mask, + batch_data, + ) = batch output = model( encoder_input_ids, # enc_input_ids @@ -574,12 +598,32 @@ def fwd_output_and_loss_func(dataloader_iter, model): decoder_attn_mask, # dec_attn_mask None, # token_type_ids lm_labels, # labels + batch_data, # batch_data ) def loss_func(output_tensor): - loss = self.loss_func(loss_mask, output_tensor) - reduced_loss = average_losses_across_data_parallel_group([loss]) - return loss, {'avg': reduced_loss} + if isinstance(output_tensor, dict): + # handle loss of hidden transformations + loss_dict = output_tensor + output_tensor = loss_dict.pop("output") + # compute reconstruction (tokens) only loss from per-token reconstruction loss + tokens_loss = self.loss_func(loss_mask, output_tensor) + loss_dict["tokens_loss"] = tokens_loss + tokens_loss_weight = loss_dict.get("tokens_loss_weight", 1.0) + # compute total loss + loss = loss_dict["loss"] = loss_dict["hiddens_loss"] + tokens_loss_weight * tokens_loss + # average losses across data parallel group + loss_dict = { + k: average_losses_across_data_parallel_group([v.mean()]) for k, v in loss_dict.items() + } + else: + # compute reconstruction (tokens) only loss from per-token reconstruction loss + loss = self.loss_func(loss_mask, output_tensor) + # average losses across data parallel group + reduced_loss = average_losses_across_data_parallel_group([loss]) + loss_dict = {'loss': reduced_loss} + + return loss, loss_dict return output, loss_func @@ -645,75 +689,104 @@ def _get_forward_output_only_func(self, arg_names, output_name, **kwargs): def fwd_output_only_func(dataloader_iter, model): batch = next(dataloader_iter) - batch = [x.cuda(non_blocking=True) for x in batch] + batch = [x.cuda(non_blocking=True) if torch.is_tensor(x) else x for x in batch] # map batch and shared args into forward args args = self._build_forward_args_from_kwargs(args_name=arg_names, args=batch, **kwargs) output = model(*args).contiguous() def id_func(output_tensor): + if isinstance(output_tensor, dict): + # handle loss of hidden transformations ("output" is the default output) + output_tensor = output_tensor["output"] + return output_tensor, {output_name: output_tensor} return output, id_func return fwd_output_only_func - def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + ########## + + def _test_validation_step(self, step_outputs, dataloader_iter, batch_idx, dataloader_idx=0): """ - return_values - if given, returns a dictionary with given keys and corresponding values + Shared code for validation and test step """ # Prefetch the dataloader_iter before fwd_bwd func to avoid PP rank 2 from waiting indefinitely with PP rank 1 reaches the end of dataloader_iter dataloader_iter, done = self._prefetch(dataloader_iter) if done: return - prefix = "test" if self.trainer.testing else "val" - loss = self.fwd_bwd_step(dataloader_iter, batch_idx, True) - if prefix == 'val': - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(loss) - else: - self.validation_step_outputs.append(loss) + + loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, 
True) + step_outputs.append(loss_dict) + + return loss_dict + + def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + """ + Runs a validation step and stores the returned loss dict for on_validation_epoch_end. + """ + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + step_outputs = self.validation_step_outputs[dataloader_idx] else: - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(loss) - else: - self.test_step_outputs.append(loss) + step_outputs = self.validation_step_outputs - return loss + return self._test_validation_step( + step_outputs=step_outputs, + dataloader_iter=dataloader_iter, + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + ) - def on_validation_epoch_end(self): + def test_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + step_outputs = self.test_step_outputs[dataloader_idx] + else: + step_outputs = self.test_step_outputs + + return self._test_validation_step( + step_outputs=step_outputs, + dataloader_iter=dataloader_iter, + batch_idx=batch_idx, + dataloader_idx=dataloader_idx, + ) + + def _test_validation_epoch_end(self, step_outputs, prefix): + """ + Shared logging for validation and test + """ # NOTE: we need to make sure outputs is not empty (this is a workaround for a bug in pytorch lightning (?)) - if not self.validation_step_outputs: - logging.warning("validation_epoch_end: outputs is empty") + if not step_outputs: + logging.warning(f"{prefix} epoch end: outputs is empty") return None - if parallel_state.is_pipeline_last_stage(): - # only the last pipeline parallel stages return loss - averaged_loss = torch.stack(self.validation_step_outputs).mean() + + # only the last pipeline parallel stages return loss + if parallel_state.is_pipeline_last_stage() and len(step_outputs): + averaged_loss = {k: torch.stack([x[k] for x in step_outputs]).mean() for k in step_outputs[0].keys()} else: - averaged_loss = torch.tensor(0.0).cuda() + # if we are here we assume that only loss is available and hidden transforms are disabled (since not supported in pipeline parallel) + averaged_loss = {'loss': torch.tensor(0.0).cuda()} # we can only log on one rank if it is rank zero so we broadcast from last rank - torch.distributed.broadcast(averaged_loss, get_last_rank()) - self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) - self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) - self.validation_step_outputs.clear() # free memory + for k, v in averaged_loss.items(): + torch.distributed.broadcast(v, get_last_rank()) + averaged_loss[k] = v + n = f'{prefix}_{k}' + # log only '*_loss' values in progress bar + self.log(n, v, prog_bar=(n.endswith("_loss")), rank_zero_only=True, batch_size=1) + + # free memory + step_outputs.clear() + return averaged_loss - def test_step(self, dataloader_iter, batch_idx): - return self.validation_step(dataloader_iter, batch_idx) + def on_validation_epoch_end(self): + # FIXME: do we need this?
'global_step' is logged in training_step + self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) + return self._test_validation_epoch_end(step_outputs=self.validation_step_outputs, prefix="val",) def on_test_epoch_end(self): - if parallel_state.is_pipeline_last_stage(): - # only the last pipeline parallel stages return loss - averaged_loss = torch.stack(self.test_step_outputs).mean() - else: - averaged_loss = torch.tensor(0.0).cuda() - - # we can only log on one rank if it is rank zero so we broadcast from last rank - torch.distributed.broadcast(averaged_loss, get_last_rank()) - self.log('test_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) - self.test_step_outputs.clear() # free memory - return averaged_loss + return self._test_validation_epoch_end(step_outputs=self.test_step_outputs, prefix="test",) def loss_func(self, loss_mask, tokens_loss): """ @@ -937,11 +1010,14 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] logging.info(f"response: {response}") return response - def encode(self, tokens_enc, enc_mask, encoder_input=None, reconfigure_microbatch=True): + def encode(self, tokens_enc, enc_mask, encoder_input=None, batch_data=None, reconfigure_microbatch=True): """ tokens_enc - encoder input tokens enc_mask - corresponding mask encoder_input - encoder input (bypass tokens), if given tokens_enc can be None. + batch_data - passed directly to all hidden transformations and losses. + Can be used to pass additional data like class label. + Format is not defined and should match the expected format of the used hiddens modules. """ # Check whether the DDP is initialized. This is needed when running inference outside of training loop. if parallel_state.is_unitialized(): @@ -987,8 +1063,8 @@ def dummy(): # build input arguments description if tokens_enc is not None: - batch_for_pipeline = [tokens_enc, enc_mask] - arg_names = ['enc_input_ids', 'enc_attn_mask'] + batch_for_pipeline = [tokens_enc, enc_mask, batch_data] + arg_names = ['enc_input_ids', 'enc_attn_mask', 'batch_data'] else: if encoder_input is None: raise ValueError("At least one of tokens_enc and encoder_input must be provided with not None value") @@ -1060,6 +1136,7 @@ def decode( ignore_ids=[], bos_id=None, # If bos=None, will use tokenizer.bos_id unless explicitly set to something else. predicted_tokens_dec=None, + batch_data=None, sampling_method: str = "greedy-search", sampling_kwargs: dict = {}, ): @@ -1168,8 +1245,8 @@ def dummy(): dec_mask = predicted_tokens_dec != tokenizer.pad_id dec_mask[:, 0] = 1 # Make sure you never mask the first token even if it is . 
- batch_for_pipeline = [enc_output, enc_output_attn_mask, predicted_tokens_dec, dec_mask] - arg_names = ['enc_output', 'enc_output_attn_mask', 'dec_input_ids', 'dec_attn_mask'] + batch_for_pipeline = [enc_output, enc_output_attn_mask, predicted_tokens_dec, dec_mask, batch_data] + arg_names = ['enc_output', 'enc_output_attn_mask', 'dec_input_ids', 'dec_attn_mask', 'batch_data'] forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="logits") fwd_bwd_func = get_forward_backward_func() diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py index 2f667d815827..c7e63e1c5a59 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py @@ -203,7 +203,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensors_list = [loss_reduced['loss'] for loss_reduced in losses_reduced_per_micro_batch] loss_tensor = torch.concat(loss_tensors_list) loss_mean = loss_tensor.mean() else: @@ -213,6 +213,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): return loss_mean def get_forward_output_and_loss_func(self): + # FIXME: consolidate this method into MegatronLMEncoderDecoderModel (or have a common base class) def fwd_output_and_loss_func(dataloader_iter, model): batch = next(dataloader_iter) batch = [x.cuda(non_blocking=True) for x in batch] @@ -226,7 +227,7 @@ def fwd_output_and_loss_func(dataloader_iter, model): def loss_func(output_tensor): loss = self.frozen_model.loss_func(loss_mask, output_tensor) reduced_loss = average_losses_across_data_parallel_group([loss]) - return loss, {'avg': reduced_loss} + return loss, {'loss': reduced_loss} return output_tensor, loss_func diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py index 0a233866cdff..3cd15100111e 100644 --- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py +++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py @@ -303,34 +303,13 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] data_iter = get_iterator_k_split(batch, get_num_microbatches()) - fwd_bwd_function = get_forward_backward_func() - - losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(), + return self._execute_fwd_bwd_function( data_iterator=data_iter, - model=[self.enc_dec_model], - num_microbatches=get_num_microbatches(), forward_only=forward_only, tensor_shape=tensor_shape, decoder_seq_length=decoder_seq_length, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None, - sequence_parallel=self.cfg.get('sequence_parallel', False), - enable_autocast=self.enable_autocast, ) - # only the last stages of the pipeline return losses - if losses_reduced_per_micro_batch: - # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in 
losses_reduced_per_micro_batch] - loss_tensor = torch.concat(loss_tensors_list) - loss_mean = loss_tensor.mean() - else: - # we're not on the last pipeline stage so no losses - loss_mean = torch.tensor(0.0).cuda() - - return loss_mean - def eval_step(self, dataloader_iter, batch_idx, dataloader_idx=0): # Add try except since dataloader_iter in PTL 2.0 doesnt catch the end of iterables try: @@ -379,18 +358,22 @@ def eval_step(self, dataloader_iter, batch_idx, dataloader_idx=0): outputs=tokens_enc, tokenizer=self.encoder_tokenizer, processor=source_processor, ) - val_outputs = { + loss_dict = { 'inputs': encoder_inputs, 'translations': preds, 'ground_truths': labels, - 'loss': reduced_loss, } + if isinstance(reduced_loss, dict): + loss_dict.update(reduced_loss) + else: + loss_dict['loss'] = reduced_loss + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(val_outputs) + self.validation_step_outputs[dataloader_idx].append(loss_dict) else: - self.validation_step_outputs.append(val_outputs) + self.validation_step_outputs.append(loss_dict) - return val_outputs + return loss_dict except StopIteration: return diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py index 6a99e908f107..51ed1c7e7ef3 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py @@ -17,6 +17,7 @@ from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule from nemo.collections.nlp.modules.common.megatron.module import MegatronModule +from nemo.collections.nlp.modules.common.megatron.transformations.megatron_hiddens import MegatronHiddensModule from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults try: @@ -44,6 +45,7 @@ def __init__( encoder_attn_mask_type: AttnMaskType = None, decoder_attn_mask_type: AttnMaskType = None, hidden_steps: int = None, + hiddens_module: MegatronHiddensModule = None, # allows for hidden state transformations before the decoder ): super(MegatronTransformerEncoderDecoderModule, self).__init__() @@ -55,6 +57,12 @@ def __init__( f"hidden_steps cannot be None for perceiver encoders. It is needed to compute the encoder-decoder cross attention mask." ) + self.hiddens_module = hiddens_module + if self.hiddens_module is not None and not isinstance(self.hiddens_module, MegatronHiddensModule): + raise TypeError( + f"hiddens_module must be of type MegatronHiddensModule, but got {type(self.hiddens_module)} instead." + ) + # try to infer mask_type if not given if encoder_attn_mask_type is None: if encoder is None: @@ -83,6 +91,20 @@ def __init__( self._encoder_key = "encoder" self._decoder_key = "decoder" + self._hiddens_module = "hiddens_module" + + def get_hiddens_mask(self, enc_attn_mask): + """ + Returns the attention mask for the output of the encoder. + Required for fixed-size bottleneck models. 
+ """ + if self.encoder is not None and isinstance(self.encoder, MegatronPerceiverEncoderModule): + # Attention mask is expected to be of shape [B x S] + hiddens_mask = torch.ones(enc_attn_mask.size(0), self.hidden_steps).to(enc_attn_mask.device) + else: + hiddens_mask = enc_attn_mask + + return hiddens_mask def encode( self, @@ -91,10 +113,11 @@ def encode( enc_layer_past=None, enc_get_key_value=False, enc_self_attention_relative_position_bias=None, + batch_data=None, ): + """Encodes embedder input using encoder""" if self.encoder is None: raise ValueError(f"Cannot call .encode(...) when self.encoder is None.") - """Encodes embedder input using encoder""" enc_output = self.encoder( enc_input=enc_input, enc_attn_mask=enc_attn_mask, @@ -103,6 +126,12 @@ def encode( enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias, ) + # apply hidden transformations if needed + if self.hiddens_module is not None: + enc_output = self.hiddens_module.apply_hidden_transforms( + {"hiddens": enc_output, "hiddens_mask": self.get_hiddens_mask(enc_attn_mask),}, batch_data=batch_data, + ) + return enc_output def decode( @@ -148,6 +177,7 @@ def forward( enc_self_attention_relative_position_bias=None, dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, + batch_data=None, ): # encoder if enc_output is None: @@ -158,6 +188,7 @@ def forward( enc_layer_past=enc_layer_past, enc_get_key_value=enc_get_key_value, enc_self_attention_relative_position_bias=enc_self_attention_relative_position_bias, + batch_data=batch_data, ) else: assert self.encoder_hidden_state is not None @@ -169,22 +200,21 @@ def forward( return enc_output # decoder - # Adjust encoder attention mask if encoder is a perceiver. - if self.encoder is not None and isinstance(self.encoder, MegatronPerceiverEncoderModule): - # Attention mask is expected to be of shape [B x S] and enc_output is of size [S x B x H]. - enc_attn_mask = torch.ones(enc_output.size(1), self.hidden_steps).to(enc_output.device) - dec_output = self.decode( dec_input=dec_input, dec_attn_mask=dec_attn_mask, - enc_output=enc_output, - enc_attn_mask=enc_attn_mask, + enc_output=enc_output["enc_output"] # enc_output is a dict if we used hidden transformations + if self.hiddens_module is not None + else enc_output, + # Adjust encoder attention mask if encoder is a perceiver. 
+ enc_attn_mask=self.get_hiddens_mask(enc_attn_mask), dec_layer_past=dec_layer_past, dec_get_key_value=dec_get_key_value, dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, ) + # if self.hiddens_module is not None enc_output is a dict, else it is a torch.tensor return dec_output, enc_output def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): @@ -195,6 +225,9 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars= state_dict_[self._encoder_key] = self.encoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars) state_dict_[self._decoder_key] = self.decoder.state_dict_for_save_checkpoint(destination, prefix, keep_vars) + if self.hiddens_module is not None: + state_dict_[self._hiddens_module] = self.hiddens_module.state_dict(destination, prefix, keep_vars) + return state_dict_ def load_state_dict(self, state_dict, strict=True): @@ -202,3 +235,5 @@ def load_state_dict(self, state_dict, strict=True): self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict) self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) + if self.hiddens_module is not None: + self.hiddens_module.load_state_dict(state_dict[self._hiddens_module], strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index fc16295020fb..928b3f6e8d83 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -28,6 +28,7 @@ KERPLERelativePositionEmbedding, T5RelativePositionEmbedding, ) +from nemo.collections.nlp.modules.common.megatron.transformations.megatron_hiddens import get_hiddens_module from nemo.collections.nlp.modules.common.megatron.utils import ( ApexGuardDefaults, build_position_ids, @@ -124,6 +125,7 @@ def __init__( share_token_embeddings=True, share_decoder_tokens_head_embeddings=True, tokens_head_bias=True, + hiddens_cfg: DictConfig = None, # allows for hidden state transformations before the decoder ): super(MegatronTokenLevelEncoderDecoderModule, self).__init__() @@ -140,6 +142,7 @@ def __init__( self.share_token_embeddings = share_token_embeddings self.share_decoder_tokens_head_embeddings = share_decoder_tokens_head_embeddings self.tokens_head_bias = tokens_head_bias + self.hiddens_cfg = hiddens_cfg encoder_kv_channels, decoder_kv_channels = self._validate_config() @@ -388,8 +391,12 @@ def __init__( use_flash_attention=decoder_cfg.get('use_flash_attention', False), ) + hiddens_module = get_hiddens_module(hiddens_cfg) self.enc_dec_model = MegatronTransformerEncoderDecoderModule( - encoder=encoder, decoder=decoder, hidden_steps=encoder_cfg.get('hidden_steps', -1), + encoder=encoder, + decoder=decoder, + hidden_steps=encoder_cfg.get('hidden_steps', -1), + hiddens_module=hiddens_module, ) self._enc_dec_model_key = "enc_dec_model" @@ -455,6 +462,10 @@ def _validate_config(self): assert ( self.share_decoder_tokens_head_embeddings ), "Decoder token embeddings and the outputlayer must be shared when using pipeline model parallel size > 1" + assert ( + self.hiddens_cfg is None + ), "Hiddens module must not be enabled when using pipeline model parallel size > 1" + return encoder_kv_channels, decoder_kv_channels def set_input_tensor(self, input_tensor): @@ -493,6 
+504,7 @@ def forward( dec_attn_mask=None, token_type_ids=None, labels=None, + batch_data=None, # additional data to be passed to hiddens module enc_output=None, # Result of running the entire encoder enc_output_attn_mask=None, enc_input=None, # Result of running encoder embedding only @@ -554,9 +566,11 @@ def forward( enc_layer_past=None, enc_get_key_value=False, enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, + batch_data=batch_data, ) else: enc_output = self.enc_dec_model.encoder_hidden_state + return enc_output else: if enc_output_attn_mask is None: @@ -598,10 +612,11 @@ def forward( enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, dec_self_attention_relative_position_bias=decoder_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=decoder_cross_attention_relative_position_bias, + batch_data=batch_data, ) if self.post_process and self.add_decoder: - dec_output, enc_output = output # [s, b, h] + dec_output, enc_output = output # [s, b, h], enc_output might be a dict if hiddens_module is used # project decoder output to vocabulary-size dimensions if self.share_decoder_tokens_head_embeddings: token_logits = self.tokens_head(dec_output, self.word_embeddings_weight()) @@ -609,6 +624,7 @@ def forward( token_logits = self.tokens_head(dec_output)[0] if labels is not None: + # compute loss here # [b, s] -> [s, b] labels = labels.transpose(0, 1).contiguous() @@ -625,11 +641,30 @@ def forward( # [s, b] -> [b, s] tokens_loss = tokens_loss.transpose(0, 1).contiguous() - return tokens_loss + # check if hiddens is used + if self.hiddens_cfg is not None: + loss_dict = self.enc_dec_model.hiddens_module.apply_loss_transforms( + outputs=enc_output, batch_data=batch_data, + ) + loss_dict["tokens_loss"] = tokens_loss + # We need to store default output in a known key, so that we can mimic default behaviour + loss_dict["output"] = tokens_loss + return loss_dict + else: + return tokens_loss else: + # else return token logits (and hiddens if needed) # [s, b, h] -> [b, s, h] token_logits = token_logits.transpose(0, 1).contiguous() - return token_logits + if self.hiddens_cfg is not None: + # return all hiddens and token logits + hiddens_dict = enc_output + hiddens_dict["token_logits"] = token_logits + # We need to store default output in a known key, so that we can mimic default behaviour + hiddens_dict["output"] = token_logits + return hiddens_dict + else: + return token_logits elif self.add_decoder and not self.add_encoder: decoder_output, _ = output diff --git a/nemo/collections/nlp/modules/common/megatron/transformations/__init__.py b/nemo/collections/nlp/modules/common/megatron/transformations/__init__.py new file mode 100644 index 000000000000..50a412ac2e13 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/transformations/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .megatron_hidden_loss import * +from .megatron_hidden_transform import * diff --git a/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_loss.py b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_loss.py new file mode 100644 index 000000000000..f10c34d3fad3 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_loss.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import torch + +__all__ = ["MegatronBaseHiddenLoss", "MegatronAMIMHiddenLoss", "MegatronVAEHiddenLoss"] + + +class MegatronBaseHiddenLoss(torch.nn.Module): + """ + Base class to calculate hidden state loss. + Returned dict includes a loss value and additional outputs. + """ + + def __init__(self, loss_weight=1.0, name=""): + super().__init__() + self.name = name + self.loss_weight = float(loss_weight) + + def __str__(self): + return super().__str__() + f"(name={self.name})" + + def _validate_inputs(self, inputs): + """Validate inputs""" + # validate inputs + if not set(self.input_names).issubset(set(inputs.keys())): + raise ValueError(f"Inputs should contain {self.input_names}, but got {inputs.keys()}") + + @property + def input_names(self): + """Returns and caches input names""" + # we always expect hiddens_mask to be used to mask out loss of padded elements + return self._input_names() + ["hiddens_mask"] + + def _input_names(self): + """Add here all required inputs""" + return [] + + def _loss(self, inputs, batch_data=None): + """ + We expect input shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + We return a dictionary with dimensions [B x S x H], [B x S], [B], or []. + + Implement your own loss calculations. Must return "loss" key. + loss shape - [B x S] for Batch, Sequence sizes + batch_data - a dictionary of additional data that can be used to calculate loss + + Returns: + dict: a dictionary with loss and additional outputs (must include "loss" key) + example: {"loss": 0.0} + """ + raise NotImplementedError("Please implement loss calculations in child class") + + def loss(self, inputs, batch_data=None): + """A wrapper around custom _loss that adds a weighted loss and name to the output dict""" + self._validate_inputs(inputs) + + loss_dict = self._loss(inputs, batch_data=batch_data) + if "loss" not in loss_dict: + raise KeyError("Loss dict must contain 'loss' key") + + # average loss over active steps only. 
loss [B x S] + loss = loss_dict["loss"] + # hiddens_mask has shape of [B x S] + hiddens_mask = inputs["hiddens_mask"].to(loss) + loss = loss * hiddens_mask + # sequence level loss [B x S] -> batch level loss [B] + loss = loss.sum(dim=1) / hiddens_mask.sum(dim=1).clamp(min=1.0) + + # compute batch level weighted loss (scalar) + weighted_loss = loss.sum() * self.loss_weight + + # store updated losses + loss_dict["loss"] = loss + loss_dict["weighted_loss"] = weighted_loss + loss_dict["weight_loss"] = torch.tensor(self.loss_weight).to(weighted_loss) + + return loss_dict + + +class MegatronAMIMHiddenLoss(MegatronBaseHiddenLoss): + """ + Based on + Implements A-MIM loss with a unit Normal anchor. + A-MIM - asymmetric MIM (without sampling) + """ + + def __init__(self, loss_weight=1.0, hidden_aggregation_method="sum", name="mim"): + super().__init__( + name=name, loss_weight=loss_weight, + ) + + # allows to determine how to aggregate hidden loss over hidden dimension + self.hidden_aggregation_method = hidden_aggregation_method + + def _input_names(self): + """Add here all required inputs""" + return ["z", "z_log_prob"] + + def _loss(self, inputs, batch_data=None): + """ + We expect input shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + We return a dictionary with dimensions [B x S x H], [B x S], [B], or []. + + Implement your own loss calculations. Must return "loss" key. + loss shape - [B x S] for Batch, Sequence sizes + batch_data - a dictionary of additional data that can be used to calculate loss + """ + z = inputs["z"] + # get posterior + log_prob_q_z_given_x = inputs["z_log_prob"] + # compute log prob of anchor a unit Normal distribution + log_prob_P_z = -0.5 * (math.log(2 * math.pi) + z.pow(2)) + # aggregate over hidden dimension, default is sum + log_prob_P_z = getattr(log_prob_P_z, self.hidden_aggregation_method)(dim=-1) + + # A-MIM loss = log_p_x_given_z - 0.5 * (log_prob_P_z + log_prob_q_z_given_x) + # here we return only the hidden loss part + loss = -0.5 * (log_prob_P_z + log_prob_q_z_given_x) + + # return losses shaped [B x S] + return { + "loss": loss.transpose(0, 1), + "log_prob_P_z": log_prob_P_z.transpose(0, 1), + "log_prob_q_z_given_x": log_prob_q_z_given_x.transpose(0, 1), + } + + +class MegatronVAEHiddenLoss(MegatronBaseHiddenLoss): + """ + Based on + Implements VAE loss with a unit Normal anchor. + """ + + def __init__(self, loss_weight=1.0, min_kl_value=None, name="vae"): + super().__init__( + name=name, loss_weight=loss_weight, + ) + + # minimum value for KL divergence + if min_kl_value is None: + self.min_kl_value = min_kl_value + else: + self.min_kl_value = float(min_kl_value) + + def _input_names(self): + """Add here all required inputs""" + return ["z", "z_log_prob"] + + def _loss(self, inputs, batch_data=None): + """ + We expect input shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + We return a dictionary with dimensions [B x S x H], [B x S], [B], or []. + + Implement your own loss calculations. Must return "loss" key. 
+ loss shape - [B x S] for Batch, Sequence sizes + batch_data - a dictionary of additional data that can be used to calculate loss + """ + z = inputs["z"] + # get posterior + log_prob_q_z_given_x = inputs["z_log_prob"] + # compute log prob of anchor a unit Normal distribution + log_prob_p_z = -0.5 * (math.log(2 * math.pi) + z.pow(2)).sum(dim=-1) + + # VAE loss = log_p_x_given_z - KL(q(z|x) || p(z)) + kl_div = log_prob_q_z_given_x - log_prob_p_z + # here we return only the hidden loss part + loss = -kl_div + + # return losses shaped [B x S] + return { + "loss": loss.transpose(0, 1), + "kl_div": kl_div.transpose(0, 1), + "log_prob_p_z": log_prob_p_z.transpose(0, 1), + "log_prob_q_z_given_x": log_prob_q_z_given_x.transpose(0, 1), + } diff --git a/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_transform.py b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_transform.py new file mode 100644 index 000000000000..1c424a6a069b --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hidden_transform.py @@ -0,0 +1,170 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math + +import torch + +from nemo.collections.nlp.modules.common.megatron.utils import init_method_normal + +try: + from megatron.core import tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +if not HAVE_MEGATRON_CORE: + raise NotImplementedError("Megatron Core is required to use Megatron Hidden Transformations") + +__all__ = ["MegatronBaseHiddenTransform", "MegatronGaussianHiddenTransform"] + + +class MegatronBaseHiddenTransform(torch.nn.Module): + """Base class to apply hidden state transformations""" + + def __init__(self, name=""): + super().__init__() + + self.name = name + + def __str__(self): + return super().__str__() + f"(name={self.name})" + + @property + def input_names(self): + """ + Provide here all required inputs + """ + return [] + + @property + def output_names(self): + """ + Provide here all generated outputs + """ + return [] + + def _validate_inputs(self, inputs): + """Validate inputs""" + # validate inputs + if not set(self.input_names).issubset(set(inputs.keys())): + raise ValueError(f"Inputs should contain {self.input_names}, but got {inputs.keys()}") + + def _transform(self, inputs, batch_data=None): + """ + Implement your own transformations. + We expect here shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + """ + # by default we pass inputs. 
+ outputs = inputs.copy() + + return outputs + + def transform(self, inputs, batch_data=None): + """Apply transformations on the inputs (hiddens is always assumed)""" + # validate inputs + self._validate_inputs(inputs) + + outputs = self._transform(inputs, batch_data=batch_data) + + return outputs + + +class MegatronGaussianHiddenTransform(MegatronBaseHiddenTransform): + """ + Constructs a diagonal Gaussian distribution from the hidden states and samples from it using reparametrization. + """ + + def __init__(self, hidden_size, min_logvar=-6, init_method_std=0.02, name="cond_gaussian"): + super().__init__(name=name) + # lower bound on the allowed log variance (for numerical stability) + self.min_logvar = min_logvar + self.hidden_size = hidden_size + # project hiddens to mean and log variance (support tensor parallelism) + self.hiddens_to_mean_logvar = tensor_parallel.ColumnParallelLinear( + hidden_size, + hidden_size * 2, + gather_output=True, + init_method=init_method_normal(init_method_std), + skip_bias_add=False, + use_cpu_initialization=False, + bias=True, + sequence_parallel_enabled=False, + async_tensor_model_parallel_allreduce=True, + gradient_accumulation_fusion=False, + ) + + @property + def input_names(self): + """ + Provide here all required inputs + """ + return ["hiddens", "hiddens_mask"] + + @property + def output_names(self): + """ + Provide here all generated outputs + """ + return ["z_mean", "z_logvar", "z", "z_log_prob"] + + def _transform(self, inputs, batch_data=None): + """ + We expect here shapes to be [S x B x H] for Sequence, Batch, Hidden sizes (due to tensor parallel support). + + inputs: + hiddens: accepts a tensor of shape [S x B x H] + + outputs: + z: a sample from the Gaussian, a tensor of shape [S x B x H] + z_mean: mean of the Gaussian, a tensor of shape [S x B x H] + z_logvar: log variance of the Gaussian, a tensor of shape [S x B x H] + z_log_prob: log probability of z under the posterior log q(z|x), a tensor of shape [S x B] + """ + hiddens = inputs["hiddens"] + # compute distribution's parameters (or use cached ones) + if "z_mean" in inputs and "z_logvar" in inputs: + z_mean = inputs["z_mean"] + z_logvar = inputs["z_logvar"] + else: + # ColumnLinear returns output and bias, we ignore bias here (already added to hiddens) + z_mean, z_logvar = self.hiddens_to_mean_logvar(hiddens)[0].chunk(2, dim=-1) + # clamp logvar + z_logvar = z_logvar.clamp(min=self.min_logvar) + # sample z with reparametrization (or use cached one) + if "z" in inputs: + z = inputs["z"] + z_log_prob = inputs.get("z_log_prob", None) + else: + e = torch.randn_like(hiddens) + z = (z_logvar * 0.5).exp() * e + z_mean + z_log_prob = None + + if z_log_prob is None: + # compute log probability of z under a diagonal Gaussian distribution + z_log_prob = -0.5 * (math.log(2 * math.pi) + z_logvar + (z - z_mean).pow(2) / z_logvar.exp()) + # sum over the last dimension (hidden_size) + z_log_prob = z_log_prob.sum(dim=-1) + + return { + "z": z, # [S x B x H] + "z_mean": z_mean, # [S x B x H] + "z_logvar": z_logvar, # [S x B x H] + "z_log_prob": z_log_prob, # [S x B] + } diff --git a/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hiddens.py b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hiddens.py new file mode 100644 index 000000000000..3e869a70f20f --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/transformations/megatron_hiddens.py @@ -0,0 +1,310 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +In order to register external hidden transforms and losses please use the following methods: +* register_hidden_loss(cls_name: str, class_path: str) +* register_hidden_transform(cls_name: str, class_path: str) + +See example config in: examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml +""" + +import functools +import itertools +from typing import List + +import torch +from omegaconf.dictconfig import DictConfig +from omegaconf.omegaconf import OmegaConf + +from nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_loss import MegatronBaseHiddenLoss +from nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_transform import ( + MegatronBaseHiddenTransform, +) +from nemo.utils import logging +from nemo.utils.model_utils import import_class_by_path + +__all__ = ["MegatronHiddensModule"] + +# a registry of all hidden losses (maps name to class path) +_LOSS_CLASS_REGISTRY = { + "a_mim": "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_loss.MegatronAMIMHiddenLoss", + "vae": "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_loss.MegatronVAEHiddenLoss", +} + +# a registry of all hidden transforms (maps name to class path) +_TRANSFORM_CLASS_REGISTRY = { + "cond_gaussian": "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_transform.MegatronGaussianHiddenTransform", +} + + +def get_registered_hiddens(): + """ + Return: + A dictionary with all registered hidden transforms and losses. + + Example: + { + "loss": ["a_mim", "vae"], + "transform": ["cond_gaussian"], + } + """ + return { + "loss": list(_LOSS_CLASS_REGISTRY.keys()), + "transform": list(_TRANSFORM_CLASS_REGISTRY.keys()), + } + + +def register_hidden_loss(cls_name: str, class_path: str): + """ + Register a hidden loss. + + Args: + cls_name: name of the class + class_path: path to the class (e.g., "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_loss.MegatronAMIMHiddenLoss") + """ + if cls_name in _LOSS_CLASS_REGISTRY: + raise ValueError(f"Cannot register duplicate hidden loss ({cls_name})") + _LOSS_CLASS_REGISTRY[cls_name] = class_path + logging.info(f"Registered hidden loss {cls_name} at {class_path}") + + +def register_hidden_transform(cls_name: str, class_path: str): + """ + Register a hidden transform. + + Args: + cls_name: name of the class + class_path: path to the class (e.g., "nemo.collections.nlp.modules.common.megatron.transformations.megatron_hidden_transform.MegatronGaussianHiddenTransform") + """ + if cls_name in _TRANSFORM_CLASS_REGISTRY: + raise ValueError(f"Cannot register duplicate hidden transform ({cls_name})") + _TRANSFORM_CLASS_REGISTRY[cls_name] = class_path + logging.info(f"Registered hidden transform {cls_name} at {class_path}") + + +def get_hiddens_module(cfg=None): + """Build a MegatronHiddensModule from a configuration cfg""" + # Build a hiddens module if config is provided.
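    # A rough illustration of the cfg layout this function consumes (the entry names
    # 'q_z_given_x' and 'mim', the hidden_size value, and the choice of 'z' as
    # enc_output_name are assumptions for this sketch, not values mandated by this module;
    # see examples/nlp/language_modeling/conf/megatron_hiddens_base_config.yaml for the
    # reference config):
    #
    #   enc_output_name: z               # which transform output is fed to the decoder
    #   transform:
    #     q_z_given_x:
    #       cls_name: cond_gaussian      # key in _TRANSFORM_CLASS_REGISTRY
    #       hidden_size: 512
    #       min_logvar: -6
    #   loss:
    #     mim:
    #       cls_name: a_mim              # key in _LOSS_CLASS_REGISTRY
    #       loss_weight: 1.0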
+ if cfg is None: + return None + + logging.info(f"NOTE: Adding hiddens transforms and losses") + + # build all hidden transforms. We support a list or a dictionary of transforms (list enforces order) + transform_cfg = cfg.get("transform", []) + if isinstance(transform_cfg, (DictConfig, dict)): + transform_cfg = [transform_cfg] + hidden_transforms = [] + # here we expect transform_cfg to be a list of dictionaries + for cur_list_cfg in transform_cfg: + for name, cur_cfg in cur_list_cfg.items(): + cls_kwargs = OmegaConf.to_container(cur_cfg) + if not "cls_name" in cls_kwargs: + raise KeyError(f"Missing 'cls_name' in hidden transform {name}") + + cls_name = cls_kwargs.pop("cls_name") + # add name based on dictionary if not given in conf + if "name" not in cls_kwargs: + cls_kwargs["name"] = name + if cls_name not in _TRANSFORM_CLASS_REGISTRY: + raise KeyError(f"Unknown hidden transform {cls_name}, available: {_TRANSFORM_CLASS_REGISTRY.keys()}") + try: + cur_transform = import_class_by_path(_TRANSFORM_CLASS_REGISTRY[cls_name])(**cls_kwargs) + except Exception as e: + logging.error(f"Failed to build hidden transform {name} with cfg={cur_cfg}") + raise e + + hidden_transforms.append(cur_transform) + logging.info(f"Added transform {name} with cfg={cur_cfg}") + + # build all hidden losses + loss_cfg = cfg.get("loss", []) + if isinstance(loss_cfg, (DictConfig, dict)): + loss_cfg = [loss_cfg] + hidden_loss_transforms = [] + # here we expect loss_cfg to be a list of dictionaries + for cur_list_cfg in loss_cfg: + for name, cur_cfg in cur_list_cfg.items(): + cls_kwargs = OmegaConf.to_container(cur_cfg) + if not "cls_name" in cls_kwargs: + raise KeyError(f"Missing 'cls_name' in hidden loss {name}") + + cls_name = cls_kwargs.pop("cls_name") + # add name based on dictionary if not given in conf + if "name" not in cls_kwargs: + cls_kwargs["name"] = name + if cls_name not in _LOSS_CLASS_REGISTRY: + raise KeyError(f"Unknown hidden loss {cls_name}, available: {_LOSS_CLASS_REGISTRY.keys()}") + try: + cur_loss = import_class_by_path(_LOSS_CLASS_REGISTRY[cls_name])(**cls_kwargs) + except Exception as e: + logging.error(f"Failed to build hidden loss {name} with cfg={cur_cfg}") + raise e + hidden_loss_transforms.append(cur_loss) + logging.info(f"Added loss {name} with cfg={cur_cfg}") + + enc_output_name = cfg.get("enc_output_name", "hiddens") + + return MegatronHiddensModule( + hidden_transforms=hidden_transforms, + hidden_loss_transforms=hidden_loss_transforms, + enc_output_name=enc_output_name, + ) + + +class MegatronHiddensModule(torch.nn.Module): + """ + This class jointly handles the hidden transforms and hidden loss transforms. + It helps in validating, and applying the transforms. 
+ """ + + def __init__( + self, + hidden_transforms: List[MegatronBaseHiddenLoss] = [], + hidden_loss_transforms: List[MegatronBaseHiddenTransform] = [], + enc_output_name: str = "hiddens", # name (key) of the encoder output + tokens_loss_weight: float = 1.0, # weight of the tokens loss + loss_prefix: str = "hiddens_", # if not None or "", add this prefix to all loss names + ): + super().__init__() + self.hidden_transforms = hidden_transforms + self.hidden_loss_transforms = hidden_loss_transforms + self.enc_output_name = enc_output_name + self.tokens_loss_weight = tokens_loss_weight + self.loss_prefix = loss_prefix + + # register all hidden / loss transforms as submodules to support learned parameters + if not all([isinstance(ht, MegatronBaseHiddenLoss) for ht in self.hidden_loss_transforms]): + raise TypeError( + f"hidden_loss_transforms should be a list of MegatronBaseHiddenLoss, but got {hidden_loss_transforms}" + ) + self.hidden_loss_transforms = torch.nn.ModuleList(self.hidden_loss_transforms) + if not all([isinstance(ht, MegatronBaseHiddenTransform) for ht in self.hidden_transforms]): + raise TypeError( + f"hidden_transforms should be a list of MegatronBaseHiddenTransform, but got {hidden_transforms}" + ) + self.hidden_transforms = torch.nn.ModuleList(self.hidden_transforms) + + # validate the inputs and outputs of all hidden transforms (make sure there are no duplicate output names) + duplicate_names = {} + # initialize with available outputs from hidden transforms with hiddens and mask as default + hidden_outputs = set(["hiddens", "hiddens_mask", "enc_output"]) + for ht in self.hidden_transforms: + # validate that all required inputs are available by order of hidden transforms + cur_input_names = set(ht.input_names) + if not cur_input_names.issubset(hidden_outputs): + raise ValueError( + f"Hidden transform {ht.name} requires inputs {cur_input_names - hidden_outputs} that are not available" + ) + + # collect all duplicate output names + cur_hidden_outputs = set(ht.output_names) + if not cur_hidden_outputs.isdisjoint(hidden_outputs): + duplicate_names[ht.name] = list(cur_hidden_outputs.intersection(hidden_outputs)) + + hidden_outputs.update(cur_hidden_outputs) + + # fail here reporting all duplicate output names + if duplicate_names: + raise ValueError( + f"Hidden transforms have duplicate outputs {{name: [duplicate outputs]}} = {duplicate_names}" + ) + + # validate that all loss transforms are supported by output of hidden transforms ("hiddens" is given by default) + loss_inputs = set(itertools.chain(*[lt.input_names for lt in self.hidden_loss_transforms])) + if not loss_inputs.issubset(hidden_outputs): + loss_inputs_dict = {lt.name: lt.input_names for lt in self.hidden_loss_transforms} + raise ValueError( + f"Loss transforms inputs = {loss_inputs - hidden_outputs} are not supported by hidden transforms with hidden_outputs = {hidden_outputs}, expected inputs per loss = {loss_inputs_dict}" + ) + + @functools.cached_property + def hidden_outputs(self): + """Get the hidden outputs from all the hidden transforms""" + all_output_names = [ht.output_names for ht in self.hidden_transforms] + [["hiddens", "hiddens_mask"]] + output_names = set().union(*all_output_names) + + return list(output_names) + + @functools.cached_property + def loss_inputs(self): + """Get the loss inputs from all the loss transforms""" + loss_inputs = set().union(*[lt.input_names for lt in self.hidden_loss_transforms]) + return list(loss_inputs) + + def apply_hidden_transforms(self, inputs, batch_data=None): + """ + 
Apply hidden transforms + Args: + inputs: a dictionary of inputs, with "hiddens" as the default key for hidden states + batch_data: a dictionary of batch data (e.g. "input_features"), optional + + Returns: + outputs: a dictionary of outputs, collecting + """ + outputs = inputs.copy() + for hidden_transform in self.hidden_transforms: + # make sure to collect all outputs from hidden transforms + outputs.update(hidden_transform.transform(outputs, batch_data=batch_data)) + + # update final encoder output + outputs["enc_output"] = outputs[self.enc_output_name] + + return outputs + + def apply_loss_transforms(self, outputs, batch_data=None): + """ + Apply loss transforms + Args: + outputs: a dictionary of outputs (after hidden transforms) + batch_data: a dictionary of batch data (e.g. "target_ids"), optional + + Returns: + loss_dict: a dictionary of all losses, + { + loss: joint loss (float), + _*: loss values from loss transforms, could be loss, or loss elements + } + """ + loss_dict = {} + joint_loss = 0.0 + for i, loss_transform in enumerate(self.hidden_loss_transforms): + cur_loss_dict = loss_transform.loss(outputs, batch_data=batch_data) + joint_loss = joint_loss + cur_loss_dict["weighted_loss"] + cur_loss_dict.pop("weighted_loss") + # add name to loss values + if loss_transform.name: + cur_loss_dict = {f"{loss_transform.name}_{k}": v for k, v in cur_loss_dict.items()} + + # check if cur_loss keys are unique - we do not allow to override keys + dup_keys = set(cur_loss_dict.keys()).intersection(set(loss_dict.keys())) + if len(dup_keys): + raise ValueError( + f"Loss transform ({i}) {loss_transform} is trying to override the following loss keys {list(dup_keys)}" + ) + # update loss dict + loss_dict.update(cur_loss_dict) + + # joint weighted loss (float) + loss_dict["loss"] = joint_loss + + # add prefix to all loss keys (default to 'hiddens_') + if self.loss_prefix: + loss_dict = {f"{self.loss_prefix}{k}": v for k, v in loss_dict.items()} + + # add tokens loss weight (to be used by caller, or be ignored) + loss_dict["tokens_loss_weight"] = torch.tensor(self.tokens_loss_weight).to(joint_loss) + + return loss_dict diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index 7c7a428fa43f..045509d5adf9 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -383,8 +383,12 @@ def get_iterator_k_split(batch: List[torch.Tensor], num_microbatches: int) -> It microbatches = [dict(elem) for elem in microbatches] else: assert batch[0].shape[0] % num_microbatches == 0, "Issue with batch size configuration!" - split_batch = [torch.tensor_split(item, num_microbatches, dim=0) for item in batch] - microbatches = [[elem[i] for elem in split_batch] for i in range(num_microbatches)] + split_batch = [ + torch.tensor_split(item, num_microbatches, dim=0) if torch.is_tensor(item) else item for item in batch + ] + microbatches = [ + [elem[i] if elem is not None else elem for elem in split_batch] for i in range(num_microbatches) + ] return itertools.chain(microbatches)
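The get_iterator_k_split change above splits tensor entries across microbatches along the batch dimension, while a None entry (the batch_data placeholder used by the encoder/decoder paths earlier in this patch) is handed to every microbatch unchanged. A minimal sketch of that behaviour, assuming a toy batch of two tensors plus a trailing None (the helper name and shapes below are illustrative, not part of the patch):

import itertools

import torch


def split_batch_sketch(batch, num_microbatches):
    # mirrors the list branch of the patched get_iterator_k_split
    split_batch = [
        torch.tensor_split(item, num_microbatches, dim=0) if torch.is_tensor(item) else item for item in batch
    ]
    microbatches = [
        [elem[i] if elem is not None else elem for elem in split_batch] for i in range(num_microbatches)
    ]
    return itertools.chain(microbatches)


tokens = torch.zeros(8, 16, dtype=torch.long)
mask = torch.ones(8, 16)
for micro_tokens, micro_mask, micro_batch_data in split_batch_sketch([tokens, mask, None], num_microbatches=2):
    assert micro_tokens.shape[0] == 4  # tensors are split along the batch dimension
    assert micro_batch_data is None  # a None batch_data is forwarded to every microbatch as-is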