Skip to content

Commit

Permalink
Add support for Numba FP16 RNNT Loss (NVIDIA#6991) (NVIDIA#7038)
Browse files Browse the repository at this point in the history
* Force working space memory to always be in fp32

Signed-off-by: smajumdar <[email protected]>

* Add support for fp16 testing in Numba

Signed-off-by: smajumdar <[email protected]>

* Add support for fp16 testing in Numba

Signed-off-by: smajumdar <[email protected]>

* Add support for fp16 testing in Numba

Signed-off-by: smajumdar <[email protected]>

* Fix cost calculation by upcasting to fp32

Signed-off-by: smajumdar <[email protected]>

* Fix cost calculation by upcasting to fp32

Signed-off-by: smajumdar <[email protected]>

* Add support to check if numba fp16 is available

Signed-off-by: smajumdar <[email protected]>

* add RNN-T loss implemented by PyTorch and test code (#5312)

* Fix the bugs in cache-aware streaming Conformer (#5032)

Signed-off-by: Vahid <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* IA3 support for GPT and T5 (#4909)

* init commit for ia3 adater training in GPT

Signed-off-by: arendu <[email protected]>

* ia3 adater training in GPT, models and adapter classes

Signed-off-by: arendu <[email protected]>

* reshape to operate even on non-contiguous tensors

Signed-off-by: arendu <[email protected]>

* configs

Signed-off-by: arendu <[email protected]>

* fixed none init

Signed-off-by: arendu <[email protected]>

* adding adapter and ia3 support for T5 based models

Signed-off-by: arendu <[email protected]>

* style fix

Signed-off-by: arendu <[email protected]>

* config update and t5 model adapter and ia3

Signed-off-by: arendu <[email protected]>

* removed unused imports

Signed-off-by: arendu <[email protected]>

* predict step for inference

Signed-off-by: arendu <[email protected]>

* style fix

Signed-off-by: arendu <[email protected]>

* style fix

Signed-off-by: arendu <[email protected]>

* adapter inference for t5

Signed-off-by: arendu <[email protected]>

* style fix

Signed-off-by: arendu <[email protected]>

* fixed bug micro and global batch size in eval

Signed-off-by: arendu <[email protected]>

* minor edit

Signed-off-by: arendu <[email protected]>

* agressive truncation if in test examples if no truncation field is given

Signed-off-by: arendu <[email protected]>

* corrected for language_model_path name changes in main

Signed-off-by: arendu <[email protected]>

* removed unused import

Signed-off-by: arendu <[email protected]>

* name change for language_model_path

Signed-off-by: arendu <[email protected]>

* include inter_attention to IA3

Signed-off-by: arendu <[email protected]>

* minor fix in confg

Signed-off-by: arendu <[email protected]>

* minor fixes

Signed-off-by: arendu <[email protected]>

* removed unused flag

Signed-off-by: arendu <[email protected]>

* addressing PR comments

Signed-off-by: arendu <[email protected]>

* address PR comments

Signed-off-by: arendu <[email protected]>

* minor fix

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* style fix

Signed-off-by: arendu <[email protected]>

* CI test

Signed-off-by: arendu <[email protected]>

* minor fix in jenkinsfile

Signed-off-by: arendu <[email protected]>

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>

* Bug fix - Limit val batches set to 1.0  (#5023)

* Bug fix

Signed-off-by: shanmugamr1992 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adressed sandeep's comments

* Fixing limit val batches support in bert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixing limit val batches support in bert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: shanmugamr1992 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* [bug_fix] kv_channels is used when available (#5066)

* fix bug s.t kv_channels is used when available

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>

* P&C Docs (#5068) (#5069)

Signed-off-by: Matvei Novikov <[email protected]>

Signed-off-by: Matvei Novikov <[email protected]>

Signed-off-by: Matvei Novikov <[email protected]>
Co-authored-by: Matvei Novikov <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Add spe_split_by_unicode_script arg (#5072)

* Add spe_split_by_unicode_script arg

Signed-off-by: Anas <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Anas <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>

* probabilites -> probabilities (#5078) (#5079)

Signed-off-by: nithinraok <[email protected]>

Signed-off-by: nithinraok <[email protected]>

Signed-off-by: nithinraok <[email protected]>
Co-authored-by: Nithin Rao <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* increase PR and Issue sweep quantity and active close PRs. (#5073)

* increase PR and Issue sweep quantity and active close PRs.

Signed-off-by: Xuesong Yang <[email protected]>

* update with stricter rules, 30 days to be stale and 7 days to be closed for both Issues and PRs.

Signed-off-by: Xuesong Yang <[email protected]>

Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* [TTS] added missing German phoneme tokenizer. (#5070) (#5074)

Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* rename to match prompt leanring (#5076)

Signed-off-by: arendu <[email protected]>

Signed-off-by: arendu <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Missing fixes from r1.11.0 to T5 finetuning eval (#5054) (#5061)

* Fixes to seq2seq eval

Signed-off-by: MaximumEntropy <[email protected]>

* Style

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>

* Notebook bug fixes (#5084) (#5085)

* Notebook bug fixes

Signed-off-by: Virginia Adams <[email protected]>

* Turned nemo install back on

Signed-off-by: Virginia Adams <[email protected]>

* reverted notebook

Signed-off-by: Virginia Adams <[email protected]>

* Updated one line in entity linking nb

Signed-off-by: Virginia Adams <[email protected]>

Signed-off-by: Virginia Adams <[email protected]>
Co-authored-by: Eric Harper <[email protected]>

Signed-off-by: Virginia Adams <[email protected]>
Co-authored-by: Virginia Adams <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* update strategy in notebook from ddp_fork to dp (#5088) (#5089)

Co-authored-by: Zhilin Wang <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Fix bug in Squeezeformer Conv block (#5011) (#5024)

* Fix bug in Squeezeformer Conv block

Signed-off-by: smajumdar <[email protected]>

* Fix kernel context

Signed-off-by: smajumdar <[email protected]>

* Fix access mixin

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* fixed megatron lm conversion bug (PTL related) (#5038) (#5063)

Signed-off-by: David Mosallanezhad <[email protected]>

Signed-off-by: David Mosallanezhad <[email protected]>
Co-authored-by: David Mosallanezhad <[email protected]>

Signed-off-by: David Mosallanezhad <[email protected]>
Co-authored-by: David <[email protected]>
Co-authored-by: David Mosallanezhad <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Fix Unhashable type list for Numba Cuda spec augment kernel (#5093) (#5094)

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Fix numba (#5098)

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Make it possible to specify output_filename in normalize_with_audio.py (#5092)

Signed-off-by: Elena Rastorgueva <[email protected]>

Signed-off-by: Elena Rastorgueva <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Greedy decoding confidence for CTC and RNNT (#4931)

* rnnt confidence draft

Signed-off-by: Aleksandr Laptev <[email protected]>

* word confidence

Signed-off-by: Aleksandr Laptev <[email protected]>

* advanced entropies added

Signed-off-by: Aleksandr Laptev <[email protected]>

* refactoring

Signed-off-by: Aleksandr Laptev <[email protected]>

* oops forgot a file

Signed-off-by: Aleksandr Laptev <[email protected]>

* metrics and benchmarking script added

Signed-off-by: Aleksandr Laptev <[email protected]>

* style fix

Signed-off-by: Aleksandr Laptev <[email protected]>

* texterrors installation added

Signed-off-by: Aleksandr Laptev <[email protected]>

* lgtm and bug fix

Signed-off-by: Aleksandr Laptev <[email protected]>

* fix comments

Signed-off-by: Aleksandr Laptev <[email protected]>

* fix typos

Signed-off-by: Aleksandr Laptev <[email protected]>

* add missing import after rebase

Signed-off-by: Aleksandr Laptev <[email protected]>

Signed-off-by: Aleksandr Laptev <[email protected]>
Co-authored-by: Aleksandr Laptev <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* [Add] SLURP models and examples (#4668)

* add model, util and loss

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* refactor

Signed-off-by: stevehuang52 <[email protected]>

* refactor annd update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update and refactor

Signed-off-by: stevehuang52 <[email protected]>

* update and refactor

Signed-off-by: stevehuang52 <[email protected]>

* update and refactor

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* update docs

Signed-off-by: stevehuang52 <[email protected]>

* update available models

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

* refactor data processing

Signed-off-by: stevehuang52 <[email protected]>

* fix typo

Signed-off-by: stevehuang52 <[email protected]>

* update docs

Signed-off-by: stevehuang52 <[email protected]>

* refactor and update

Signed-off-by: stevehuang52 <[email protected]>

* update doc

Signed-off-by: stevehuang52 <[email protected]>

* move transformer to asr.modules

Signed-off-by: stevehuang52 <[email protected]>

* move transformer to asr.modules

Signed-off-by: stevehuang52 <[email protected]>

* get rid of jsonlines

Signed-off-by: stevehuang52 <[email protected]>

* refactor

Signed-off-by: stevehuang52 <[email protected]>

* revert changes to nlp

Signed-off-by: stevehuang52 <[email protected]>

Signed-off-by: stevehuang52 <[email protected]>
Signed-off-by: He Huang (Steve) <[email protected]>
Co-authored-by: Jagadeesh Balam <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* only optimize params that are part of the adapter modules (#5086)

Signed-off-by: arendu <[email protected]>

Signed-off-by: arendu <[email protected]>
Co-authored-by: Virginia Adams <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Pipeline Parallel T5 Prompt Learning (#4956)

* Added pre process flag checks and pipeline parallel in fwd

Signed-off-by: Virginia Adams <[email protected]>

* Added rank check for pipeline parallel

Signed-off-by: Virginia Adams <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* T5 prompt learning works!

Signed-off-by: Virginia Adams <[email protected]>

* IA3 passing CI

Signed-off-by: Virginia Adams <[email protected]>

* Fixed typo

Signed-off-by: Virginia Adams <[email protected]>

* removed optimizer setup so Adi's change will not conflict

Signed-off-by: Virginia Adams <[email protected]>

Signed-off-by: Virginia Adams <[email protected]>
Signed-off-by: Adi Renduchintala <[email protected]>
Co-authored-by: Adi Renduchintala <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>

* [TTS] remove phonemizer.py (#5090)

remove phonemizer.py and convert code block to markdown in the tutorial.

Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* T5 Decoding with PP > 2 fix (#5091) (#5103)

* set sequence lenghts in the pipeline properly

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* [TTS] fixed wrong val loss for epoch 0 and inconsistent metrics names (#5087) (#5102)

* fixed hifigan configs as well
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Xuesong Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>

* Fix and refactor consumed samples save/restore for Megatron models. (#5077)

* Fixes and refactor

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* Remove unused imports

Signed-off-by: MaximumEntropy <[email protected]>

* Empty

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* RIR corpus generator tool (#4927)

Signed-off-by: Ante Jukić <[email protected]>

Signed-off-by: Ante Jukić <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Multiprocessing fix (#5106) (#5107)

Signed-off-by: Matvei Novikov <[email protected]>

Signed-off-by: Matvei Novikov <[email protected]>

Signed-off-by: Matvei Novikov <[email protected]>
Co-authored-by: Matvei Novikov <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* [Bug fix] PC lexical + audio (#5109) (#5110)

* training running

Signed-off-by: ekmb <[email protected]>

* revert

Signed-off-by: ekmb <[email protected]>

* revert

Signed-off-by: ekmb <[email protected]>

Signed-off-by: ekmb <[email protected]>

Signed-off-by: ekmb <[email protected]>
Co-authored-by: Evelina <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* [Fix] schedulers with no max_steps param (#4564)

* fix schedulers

Signed-off-by: stevehuang52 <[email protected]>

* update to use python inspect module

Signed-off-by: stevehuang52 <[email protected]>

* update

Signed-off-by: stevehuang52 <[email protected]>

Signed-off-by: stevehuang52 <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* T5 prompt learning fixes missing from r.11.0 merge (#5075) (#5101)

* Fix special tokens

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* Empty

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: David <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
Co-authored-by: David <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* [TTS] Add NeMo TTS Primer Tutorial (#4933)

* [TTS] Add NeMo TTS Primer Tutorial

Signed-off-by: Ryan <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Add Squeezeformer CTC model checkpoints on Librispeech (#5121)

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* adding loss normalization options to rnnt joint  (#4829)

* adding normalization options to rnnt joint loss

* moving the param to joint

* moving loss normalization to rnnt loss config

* style

* cleaning up

* fixing sum reduction in joint

Signed-off-by: Dima Rekesh <[email protected]>

* moving reduction into RNNT loss class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactoring

* typos

Signed-off-by: Dima Rekesh <[email protected]>

Signed-off-by: Dima Rekesh <[email protected]>
Co-authored-by: Dima Rekesh <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>

* Asr concat dataloader (#5108)

* forced precision

* typo

* initial commit

Signed-off-by: Dima Rekesh <[email protected]>

* typos and bugs

Signed-off-by: Dima Rekesh <[email protected]>

* reverting conformer encoder

Signed-off-by: Dima Rekesh <[email protected]>

* additional checks

Signed-off-by: Dima Rekesh <[email protected]>

* adding support to CTC models as well

* reverting conformer_encoder

Signed-off-by: Dima Rekesh <[email protected]>

* typo

Signed-off-by: Dima Rekesh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactoring

Signed-off-by: Dima Rekesh <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactoring

Signed-off-by: Dima Rekesh <[email protected]>

* merging

Signed-off-by: Dima Rekesh <[email protected]>

Signed-off-by: Dima Rekesh <[email protected]>
Signed-off-by: Dima Rekesh <[email protected]>
Co-authored-by: Dima Rekesh <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Somshubra Majumdar <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* fix blossom ci unittests

Signed-off-by: Oleksii Kuchaiev <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* bugfix: pybtex.database.InvalidNameString: Too many commas in author field. (#5112) (#5115)

Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Uppdate container version to 22.09 (#5105)

* update container version

Signed-off-by: ericharper <[email protected]>

* pin click

Signed-off-by: ericharper <[email protected]>

* pin click 8.0.2

Signed-off-by: ericharper <[email protected]>

Signed-off-by: ericharper <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Remove unsupported arguments from MegatronNMT (#5065)

* Fixes

Signed-off-by: MaximumEntropy <[email protected]>

* Fixes

Signed-off-by: MaximumEntropy <[email protected]>

* Style

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* More fixes

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* pp2 support for T5 IA3 learning and T5 Adapters learning (#5116)

* enabling pp2

Signed-off-by: arendu <[email protected]>

* optimizer update

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* T5 pp>1 support for adapters and ia3

Signed-off-by: arendu <[email protected]>

* fix bug with missing adapter_tuning

Signed-off-by: arendu <[email protected]>

* inference error fixed, pp=2

Signed-off-by: arendu <[email protected]>

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* T5 Prompt Learning Fixes for Pipeline Parallel (#5120)

* Initial fixes

Signed-off-by: MaximumEntropy <[email protected]>

* Added back validation acc

Signed-off-by: Virginia Adams <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Put num workers back

Signed-off-by: Virginia Adams <[email protected]>

* added relative encoding if statament

Signed-off-by: Virginia Adams <[email protected]>

* Added back val loss only validation

Signed-off-by: Virginia Adams <[email protected]>

* Revert "Added back val loss only validation"

This reverts commit 86d8f4806fe30335c40c3716ce18259939df500f.

* Removed val acc for PP > 1

Signed-off-by: Virginia Adams <[email protected]>

* Removed enc_seq_len if statement

Signed-off-by: Virginia Adams <[email protected]>

* Added back validation acc calc

Signed-off-by: Virginia Adams <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: MaximumEntropy <[email protected]>
Signed-off-by: Virginia Adams <[email protected]>
Signed-off-by: Virginia Adams <[email protected]>
Co-authored-by: Virginia Adams <[email protected]>
Co-authored-by: Virginia Adams <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Virginia Adams <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* add doc info (#4721)

Signed-off-by: Yang Zhang <[email protected]>

Signed-off-by: Yang Zhang <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* [TTS] Add SpanishCharsTokenizer (#5135)

* [TTS] Add SpanishCharsTokenizer

Signed-off-by: Ryan <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Update megatron interface to dialogue (#4936)

* fix style formatting

Signed-off-by: Zhilin Wang <[email protected]>

* update template to include description of intent

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile

Signed-off-by: Zhilin Wang <[email protected]>

* changes based on requests in review

Signed-off-by: Zhilin Wang <[email protected]>

* add compatibility with assistant dataset

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkins

Signed-off-by: Zhilin Wang <[email protected]>

* remove dialogue_state_tracking

Signed-off-by: Zhilin Wang <[email protected]>

* update huggingface utils for dialogue

Signed-off-by: Zhilin Wang <[email protected]>

* rename dialogue_state_tracking_hybrid to dialogue_state_tracking_sgdqa

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* fix style

Signed-off-by: Zhilin Wang <[email protected]>

* style fix nemo/collections/nlp/models/dialogue_state_tracking_sgdqa/__init__.py

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile for SGDGEN

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile for SGDGEN

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile for SGDGEN

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile for SGDGEN

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile for SGDGEN

Signed-off-by: Zhilin Wang <[email protected]>

* fix typo

Signed-off-by: Zhilin Wang <[email protected]>

* add docstrings for assistant data processsor

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkins for SGDGEN local checkpoint

Signed-off-by: Zhilin Wang <[email protected]>

* update style

Signed-off-by: Zhilin Wang <[email protected]>

* use local vocab file for Jenkinsfile

Signed-off-by: Zhilin Wang <[email protected]>

* patch for Jenkins CI using local file

Signed-off-by: Zhilin Wang <[email protected]>

* add slot filling prediction and metrics

Signed-off-by: Zhilin Wang <[email protected]>

* remove unused code

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* refactor metrics code out of Dialogue GPT Model

Signed-off-by: Zhilin Wang <[email protected]>

* integrate backward compatible support for IntentSlotClassificationModel (bert model)

Signed-off-by: Zhilin Wang <[email protected]>

* save prediction file for IntentSlotClassification

Signed-off-by: Zhilin Wang <[email protected]>

* update dialogue gpt model training for megatron gpt

Signed-off-by: Zhilin Wang <[email protected]>

* remove batch generate for HF GPT2, which causes lower performance

Signed-off-by: Zhilin Wang <[email protected]>

* add few shot capability to dialogue gpt model

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile and remove unused import

Signed-off-by: Zhilin Wang <[email protected]>

* update code description and clarity

Signed-off-by: Zhilin Wang <[email protected]>

* address PR comments

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* integrate compatibility with ZeroShotIntentModel

Signed-off-by: Zhilin Wang <[email protected]>

* rename folder to dialogue due to increased scope and further refactor for clarity

Signed-off-by: Zhilin Wang <[email protected]>

* added dialogue GPT for sequence generation task (e.g. answer extender)

Signed-off-by: Zhilin Wang <[email protected]>

* add CI test for DialogueGPTGenerationModel

Signed-off-by: Zhilin Wang <[email protected]>

* integrate DialogueS2SGenerationModel for generation task (e.g. answer extender)

Signed-off-by: Zhilin Wang <[email protected]>

* modify huggingface utils to support HF t5/BART models

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* remove unused imports

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile

Signed-off-by: Zhilin Wang <[email protected]>

* update bleu metric

Signed-off-by: Zhilin Wang <[email protected]>

* fix bleu metric style

Signed-off-by: Zhilin Wang <[email protected]>

* debug bleu metric

Signed-off-by: Zhilin Wang <[email protected]>

* debug bleu metric

Signed-off-by: Zhilin Wang <[email protected]>

* update based on PR #3893

Signed-off-by: Zhilin Wang <[email protected]>

* update 2 based on PR #3893

Signed-off-by: Zhilin Wang <[email protected]>

* update 3 based on PR #3893

Signed-off-by: Zhilin Wang <[email protected]>

* integrate sgd generation based on user user utterance and system slot-values to generate system utterance

Signed-off-by: Zhilin Wang <[email protected]>

* add validation model saving capabilities

Signed-off-by: Zhilin Wang <[email protected]>

* cleaned up code for SGD Based Answer extender

Signed-off-by: Zhilin Wang <[email protected]>

* update Dialogue Generation CI

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkinsfile

Signed-off-by: Zhilin Wang <[email protected]>

* fix Jenkins CI issue"

Signed-off-by: Zhilin Wang <[email protected]>

* add support for design dataset

Signed-off-by: Zhilin Wang <[email protected]>

* remove unnecessary imports

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkins

Signed-off-by: Zhilin Wang <[email protected]>

* update jenkins

Signed-off-by: Zhilin Wang <[email protected]>

* update jenkins

Signed-off-by: Zhilin Wang <[email protected]>

* support megatron for dialogue_s2s_generation_model

Signed-off-by: Zhilin Wang <[email protected]>

* reduce loaded samples in MSMarcoDataProcessor to 64 when cfg.model.dataset.debug_mode=True

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* update CI

Signed-off-by: Zhilin Wang <[email protected]>

* update checkpoint and predictions filename to include epoch number

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* integrate HF BART MNLI into zero shot intent model

Signed-off-by: Zhilin Wang <[email protected]>

* integrate Dialogue Nearest Neighbour Model

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkins

Signed-off-by: Zhilin Wang <[email protected]>

* update Jenkins

Signed-off-by: Zhilin Wang <[email protected]>

* refactor Dialogue SGD Data Processor to make interface for models cleaner

Signed-off-by: Zhilin Wang <[email protected]>

* update jenkins

Signed-off-by: Zhilin Wang <[email protected]>

* update Dialogue S2S Generation model for DialogueSGDDataProcessor interface

Signed-off-by: Zhilin Wang <[email protected]>

* update jenkins

Signed-off-by: Zhilin Wang <[email protected]>

* update jenkins

Signed-off-by: Zhilin Wang <[email protected]>

* support sgd and drive thru datasets by zero shot model and nearest neighbour model

Signed-off-by: Zhilin Wang <[email protected]>

* add prediction saving code to nearest neighbour and zero shot intent models

Signed-off-by: Zhilin Wang <[email protected]>

* fix typo in sgd data processor

Signed-off-by: Zhilin Wang <[email protected]>

* integrate Dialogue Mellon QA Data Processor

Signed-off-by: Zhilin Wang <[email protected]>

* update mellon qa

Signed-off-by: Zhilin Wang <[email protected]>

* update dialogue.py to remove outdated info

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* update dialogue_config.yaml

Signed-off-by: Zhilin Wang <[email protected]>

* update dialogue_config.yaml

Signed-off-by: Zhilin Wang <[email protected]>

* add dialogue docs

Signed-off-by: Zhilin Wang <[email protected]>

* address review comments

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix for cfg

Signed-off-by: Zhilin Wang <[email protected]>

* make dependency on apex optional

Signed-off-by: Zhilin Wang <[email protected]>

* change NLPDDPluggin calling logic to make it possible to run without apex

Signed-off-by: Zhilin Wang <[email protected]>

* add first draft of tutorial

Signed-off-by: Zhilin Wang <[email protected]>

* reduce ms marco size by removing lines without wellFormedAnswers

Signed-off-by: Zhilin Wang <[email protected]>

* address pr comments

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* update colab tutorial link in dialogue docs

Signed-off-by: Zhilin Wang <[email protected]>

* include unit test and some refactor to facilitate unit test

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* address pr issues

Signed-off-by: Zhilin Wang <[email protected]>

* remove typos in dialogue tutorial

Signed-off-by: Zhilin Wang <[email protected]>

* support larger files for question answering

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* remove unnecessary artifacts to reduce memory use

Signed-off-by: Zhilin Wang <[email protected]>

* put 0 tensor to device

Signed-off-by: Zhilin Wang <[email protected]>

* update link within dialogue tutorial

Signed-off-by: Zhilin Wang <[email protected]>

* restore previously delete files

Signed-off-by: Zhilin Wang <[email protected]>

* update error handling when loss = nan

Signed-off-by: Zhilin Wang <[email protected]>

* update nan handling

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* update spanning loss func

Signed-off-by: Zhilin Wang <[email protected]>

* update spanning loss

Signed-off-by: Zhilin Wang <[email protected]>

* fix type error raised in qa_dataset.py

Signed-off-by: Zhilin Wang <[email protected]>

* add error checking message

Signed-off-by: Zhilin Wang <[email protected]>

* revert back to float32

Signed-off-by: Zhilin Wang <[email protected]>

* revert back to float32

Signed-off-by: Zhilin Wang <[email protected]>

* update error msgs

Signed-off-by: Zhilin Wang <[email protected]>

* update error msgs

Signed-off-by: Zhilin Wang <[email protected]>

* update error msgs

Signed-off-by: Zhilin Wang <[email protected]>

* update error msgs

Signed-off-by: Zhilin Wang <[email protected]>

* update error msgs

Signed-off-by: Zhilin Wang <[email protected]>

* update error msgs

Signed-off-by: Zhilin Wang <[email protected]>

* update error msgs

Signed-off-by: Zhilin Wang <[email protected]>

* update error msgs

Signed-off-by: Zhilin Wang <[email protected]>

* update exp logging

Signed-off-by: Zhilin Wang <[email protected]>

* update error msgs

Signed-off-by: Zhilin Wang <[email protected]>

* update loading of large file from pickle to json

Signed-off-by: Zhilin Wang <[email protected]>

* update loading of large file from pickle to json

Signed-off-by: Zhilin Wang <[email protected]>

* limit number of negative samples

Signed-off-by: Zhilin Wang <[email protected]>

* revert post processing

Signed-off-by: Zhilin Wang <[email protected]>

* revert post processing

Signed-off-by: Zhilin Wang <[email protected]>

* remove unused methods and style fix

Signed-off-by: Zhilin Wang <[email protected]>

* add more documentation

Signed-off-by: Zhilin Wang <[email protected]>

* remove unused imports

Signed-off-by: Zhilin Wang <[email protected]>

* changes base on PR review

Signed-off-by: Zhilin Wang <[email protected]>

* set wandb logger falseby default

Signed-off-by: Zhilin Wang <[email protected]>

* update interface with megatron gpt prompt learning

Signed-off-by: Zhilin Wang <[email protected]>

* update inline documentation

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* update prompt_ids

Signed-off-by: Zhilin Wang <[email protected]>

* update error msg

Signed-off-by: Zhilin Wang <[email protected]>

* update config

Signed-off-by: Zhilin Wang <[email protected]>

* update config

Signed-off-by: Zhilin Wang <[email protected]>

* set inference = False for dialgue prompt learning during trainng

Signed-off-by: Zhilin Wang <[email protected]>

* set inference = False for dialgue prompt learning during trainng

Signed-off-by: Zhilin Wang <[email protected]>

* remove unused code

Signed-off-by: Zhilin Wang <[email protected]>

* update config yaml

Signed-off-by: Zhilin Wang <[email protected]>

* fix bug for megatron gpt prompt learning

Signed-off-by: Zhilin Wang <[email protected]>

* remove unused import

Signed-off-by: Zhilin Wang <[email protected]>

* address comments in PR

Signed-off-by: Zhilin Wang <[email protected]>

* address comments in PR

Signed-off-by: Zhilin Wang <[email protected]>

* address typo

Signed-off-by: Zhilin Wang <[email protected]>

* add megatron t5 inference

Signed-off-by: Zhilin Wang <[email protected]>

* fix bug due to bert tokenizer not being space-aware

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* update style

Signed-off-by: Zhilin Wang <[email protected]>

* update IntentSlotModel onnx export test

Signed-off-by: Zhilin Wang <[email protected]>

* update style

Signed-off-by: Zhilin Wang <[email protected]>

* update exportable

Signed-off-by: Zhilin Wang <[email protected]>

* address PR comments

Signed-off-by: Zhilin Wang <[email protected]>

* replace functools.cache_property with functools.lru_cache to maintain python 3.7 compatibility

Signed-off-by: Zhilin Wang <[email protected]>

* improve speed of rank_candidates and support for p tuning

Signed-off-by: Zhilin Wang <[email protected]>

* update dialogue.py

Signed-off-by: Zhilin Wang <[email protected]>

* fix megatron prompt learning saving bug

Signed-off-by: Zhilin Wang <[email protected]>

* update generate_candidate method

Signed-off-by: Zhilin Wang <[email protected]>

* remove repeated init text ids and invert attention masks

Signed-off-by: Zhilin Wang <[email protected]>

* update typo

Signed-off-by: Zhilin Wang <[email protected]>

* custom collate fn to remove excess padding in batch

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* style fix

Signed-off-by: Zhilin Wang <[email protected]>

* update complete method to mitigate issue when max seq len is low

Signed-off-by: Zhilin Wang <[email protected]>

* address pr comments

Signed-off-by: Zhilin Wang <[email protected]>

* update generation interface

Signed-off-by: Zhilin Wang <[email protected]>

Signed-off-by: Zhilin Wang <[email protected]>
Co-authored-by: Zhilin Wang <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: Yang Zhang <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Added save inference ready .nemo file with every checkpoint (#5055)

* Added save inference ready .nemo file with every checkpoint

Signed-off-by: Virginia Adams <[email protected]>

* Python style fix

Signed-off-by: Virginia Adams <[email protected]>

* addressed Adi's comment

Signed-off-by: Virginia Adams <[email protected]>

* Added ptuning check in model checkpoint saving

Signed-off-by: Virginia Adams <[email protected]>

* Changed save_nemo_on_valdaition default to False

Signed-off-by: Virginia Adams <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Changes global batch size of adapter CI

Signed-off-by: Virginia Adams <[email protected]>

* Changed num workers to 0

Signed-off-by: Virginia Adams <[email protected]>

* added first stage of pipeline check

Signed-off-by: Virginia Adams <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Virginia Adams <[email protected]>
Signed-off-by: Virginia Adams <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>

* Fixes for docs/typos + remove max_utts parameter from tarred datasets as it causes hang in training (#5118)

* Remove ; from jupyter notebook cells

Signed-off-by: Igor Gitman <[email protected]>

* Fix typos in documentation/code

Signed-off-by: Igor Gitman <[email protected]>

* Fix output message to have 'or equal'

Signed-off-by: Igor Gitman <[email protected]>

* Link formatting fixes

Signed-off-by: Igor Gitman <[email protected]>

* Add error if max_utts is used in tarred datasets

Signed-off-by: Igor Gitman <[email protected]>

* Remove max_utts parameter from tarred datasets

Signed-off-by: Igor Gitman <[email protected]>

* Fix max_utts removal in tests

Signed-off-by: Igor Gitman <[email protected]>

* Fix typo if -> is

Signed-off-by: Igor Gitman <[email protected]>

Signed-off-by: Igor Gitman <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Merge r1.12.0 main (#5139)

* update branch

Signed-off-by: ericharper <[email protected]>

* Add cherry-pick action (#4958)

* add cherry-pick action

Signed-off-by: ericharper <[email protected]>

* Pin Transformers version to fix CI (#4955)

* Pin transformers version in CI to prevent offline tokenizer loading error

Signed-off-by: SeanNaren <[email protected]>

* Drop version

Signed-off-by: SeanNaren <[email protected]>

* Disable offline temporarily

Signed-off-by: SeanNaren <[email protected]>

* Disable offline temporarily

Signed-off-by: SeanNaren <[email protected]>

* Enable offline

Signed-off-by: SeanNaren <[email protected]>

Signed-off-by: SeanNaren <[email protected]>

Signed-off-by: ericharper <[email protected]>
Signed-off-by: SeanNaren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>

* upper bound transformers

Signed-off-by: ericharper <[email protected]>

* remove duplicate transformers requirement

Signed-off-by: ericharper <[email protected]>

* Release SOTA Lang ID model  (#5080)

* add pretrained lang id model ambernet

Signed-off-by: fayejf <[email protected]>

* update doc and style fix

Signed-off-by: fayejf <[email protected]>

Signed-off-by: fayejf <[email protected]>

* update branch and package info

Signed-off-by: ericharper <[email protected]>

* remove upper bounds on lightning and transformers

Signed-off-by: ericharper <[email protected]>

* remove transformers offline from ci

Signed-off-by: ericharper <[email protected]>

* upper bound transformers

Signed-off-by: ericharper <[email protected]>

Signed-off-by: ericharper <[email protected]>
Signed-off-by: SeanNaren <[email protected]>
Signed-off-by: fayejf <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: fayejf <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Added ASR model comparison to SDE (#5043)

SDE: Added ASR model comparison tool to SDE
transcribe speech: Added support for many predictions in one file, as well as custom field names
Signed-off-by: George Zelenfroynd <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* fix nmt eval sampler (#5154)

Signed-off-by: Abhinav Khattar <[email protected]>

Signed-off-by: Abhinav Khattar <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Fix Global init steps (#5143)

* move global step to base

Signed-off-by: Yi Dong <[email protected]>

* fix fused softmax

Signed-off-by: Yi Dong <[email protected]>

* add the missing file

Signed-off-by: Yi Dong <[email protected]>

* update the fused kernel

Signed-off-by: Yi Dong <[email protected]>

* fix import error

Signed-off-by: Yi Dong <[email protected]>

* fix import again

Signed-off-by: Yi Dong <[email protected]>

Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: Yi Dong <[email protected]>
Co-authored-by: Yi Dong <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* [TTS] bug fix - sample rate was being ignored in vocoder dataset (#4518)

* bug fix - sample rate was being ignored in vocoder dataset when not loading mel
* handled n segments for a different sampling rate than original sampling rate
* Added case for n_segments 0, warning for n_segments greater than file length

Signed-off-by: Paarth Neekhara <[email protected]>
Co-authored-by: Xuesong Yang <[email protected]>
Co-authored-by: Jocelyn <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Add EMA support to NeMo (#4764)

* Added Base files

Signed-off-by: SeanNaren <[email protected]>

* Some refactors, swap to using MNIST Lnet

Signed-off-by: SeanNaren <[email protected]>

* Add a few more tests, allow the callback to be set via the exp manager

Signed-off-by: SeanNaren <[email protected]>

* Actually run validation for testing

Signed-off-by: SeanNaren <[email protected]>

* Run isort

Signed-off-by: SeanNaren <[email protected]>

* Add test for saving state/fix saving state

Signed-off-by: SeanNaren <[email protected]>

* Use dummy model

Signed-off-by: SeanNaren <[email protected]>

* Fix test

Signed-off-by: SeanNaren <[email protected]>

* Add copyright

Signed-off-by: SeanNaren <[email protected]>

* Support saving separate EMA weight module

Signed-off-by: SeanNaren <[email protected]>

* Add standalone functionality/logging

Signed-off-by: SeanNaren <[email protected]>

* Expose more parameters

Signed-off-by: SeanNaren <[email protected]>

* Modify to allow option to replace validation

Signed-off-by: SeanNaren <[email protected]>

* Add jenkins test, formatting

Signed-off-by: SeanNaren <[email protected]>

* Pin Transformers version to fix CI (#4955)

* Pin transformers version in CI to prevent offline tokenizer loading error

Signed-off-by: SeanNaren <[email protected]>

* Drop version

Signed-off-by: SeanNaren <[email protected]>

* Disable offline temporarily

Signed-off-by: SeanNaren <[email protected]>

* Disable offline temporarily

Signed-off-by: SeanNaren <[email protected]>

* Enable offline

Signed-off-by: SeanNaren <[email protected]>

Signed-off-by: SeanNaren <[email protected]>

* Add cherry-pick action (#4958) (#4961)

* add cherry-pick action

Signed-off-by: ericharper <[email protected]>

* Pin Transformers version to fix CI (#4955)

* Pin transformers version in CI to prevent offline tokenizer loading error

Signed-off-by: SeanNaren <[email protected]>

* Drop version

Signed-off-by: SeanNaren <[email protected]>

* Disable offline temporarily

Signed-off-by: SeanNaren <[email protected]>

* Disable offline temporarily

Signed-off-by: SeanNaren <[email protected]>

* Enable offline

Signed-off-by: SeanNaren <[email protected]>

Signed-off-by: SeanNaren <[email protected]>

Signed-off-by: ericharper <[email protected]>
Signed-off-by: SeanNaren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>

Signed-off-by: ericharper <[email protected]>
Signed-off-by: SeanNaren <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* Fix changelog builder (#4962) (#4963)

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* fix cherry pick workflow (#4964) (#4965)

Signed-off-by: ericharper <[email protected]>

Signed-off-by: ericharper <[email protected]>

Signed-off-by: ericharper <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* reorder model check (#4959) (#4967)

Signed-off-by: nithinraok <[email protected]>

Signed-off-by: nithinraok <[email protected]>

Signed-off-by: nithinraok <[email protected]>
Co-authored-by: Nithin Rao <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* check for active conda environment (#4970) (#4971)

Signed-off-by: SeanNaren <[email protected]>

* [TTS] fix broken tutorial for MixerTTS. (#4949) (#4976)

Signed-off-by: Xuesong Yang <[email protected]>

Signed-off-by: Xuesong Yang <[email protected]>

Signed-off-by: Xuesong Yang <[email protected]>
Co-authored-by: Xuesong Yang <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* Checkpoint averaging class fix (#4946)

* 1. Added args.class_path to provide it externally.

Signed-off-by: Micha Livne <[email protected]>

* 1. Fixed style.

Signed-off-by: Micha Livne <[email protected]>

Signed-off-by: Micha Livne <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* Add ability to give seperate datasets for test, train and validation (#4798)

* Add ability to give seperate datasets for test, train and validation

* Addressed Sandeeps comments

* Addressed Sandeeps comments

* Add ability to give seperate datasets for test, train and validation

* Add ability to give seperate datasets for test, train and validation

* Addressed review comments

* Bug fix for common dataset utils

* Add CI tests

Signed-off-by: shanmugamr1992 <[email protected]>

* Reformat code

Signed-off-by: shanmugamr1992 <[email protected]>

* Bug fix

Signed-off-by: shanmugamr1992 <[email protected]>

* Bug fix

* Bug Fix

* Bug Fix

* Update Jenkinsfile

* Addressed comments

* Addressed Eriks comments.

* Addressed Sandeep

* Update Jenkinsfile

* Update Jenkinsfile

* Update dataset_utils.py

* Update Jenkinsfile

* Update Jenkinsfile

* Use GPT CI config

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: shanmugamr1992 <[email protected]>
Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: MaximumEntropy <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* fix label models restoring issue from wrighted cross entropy (#4968) (#4975)

Signed-off-by: nithinraok <[email protected]>

Signed-off-by: nithinraok <[email protected]>

Signed-off-by: nithinraok <[email protected]>
Co-authored-by: Nithin Rao <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* Add simple pre-commit file (#4983)

* Add simple pre-commit file

Signed-off-by: SeanNaren <[email protected]>

* Exclude docs folder

Signed-off-by: SeanNaren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: SeanNaren <[email protected]>

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit 053bd5ba579537a5f311b431871c21f3381b43eb.

Signed-off-by: SeanNaren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: SeanNaren <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: SeanNaren <[email protected]>

* Import pycuda.autoprimaryctx or pycuda.autoinit to init pycuda execution environment (#4951)

Signed-off-by: Jin Li <[email protected]>

Signed-off-by: Jin Li <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* Adding speaker embedding conditioning in fastpitch (#4986)

Signed-off-by: subhankar-ghosh <[email protected]>

Signed-off-by: subhankar-ghosh <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* Fix ASR issues (#4984) (#4991)

* Fix ASR issues

Signed-off-by: smajumdar <[email protected]>

* Revert fix

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Signed-off-by: SeanNaren <[email protected]>

* Fix current tests

Signed-off-by: SeanNaren <[email protected]>

* More test coverage

Signed-off-by: SeanNaren <[email protected]>

* Address reviews

Signed-off-by: SeanNaren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review

Signed-off-by: SeanNaren <[email protected]>

* Drop bf16 test

Signed-off-by: SeanNaren <[email protected]>

* Address review

Signed-off-by: SeanNaren <[email protected]>

* remove print

Signed-off-by: SeanNaren <[email protected]>

* Add bf16

Signed-off-by: SeanNaren <[email protected]>

Signed-off-by: SeanNaren <[email protected]>
Signed-off-by: ericharper <[email protected]>
Signed-off-by: smajumdar <[email protected]>
Signed-off-by: nithinraok <[email protected]>
Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: Micha Livne <[email protected]>
Signed-off-by: shanmugamr1992 <[email protected]>
Signed-off-by: MaximumEntropy <[email protected]>
Signed-off-by: Jin Li <[email protected]>
Signed-off-by: subhankar-ghosh <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
Co-authored-by: Nithin Rao <[email protected]>
Co-authored-by: Xuesong Yang <[email protected]>
Co-authored-by: Micha Livne <[email protected]>
Co-authored-by: shanmugamr1992 <[email protected]>
Co-authored-by: MaximumEntropy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: liji-nv <[email protected]>
Co-authored-by: Subhankar Ghosh <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Fix BF16 test (#5162)

Signed-off-by: SeanNaren <[email protected]>

Signed-off-by: SeanNaren <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Fix errors in speaker diarization nemo docs (#5153)

* fix docs and docstrings for MSDD

Signed-off-by: Taejin Park <[email protected]>

* fix nemo docs errors

Signed-off-by: Taejin Park <[email protected]>

* reflected review comments

Signed-off-by: Taejin Park <[email protected]>

Signed-off-by: Taejin Park <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* Add interleaved pipeline schedule to GPT (#5025)

* add virtual pipeline size to config

Signed-off-by: ericharper <[email protected]>

* convert model to list of modules

Signed-off-by: ericharper <[email protected]>

* convert model to list of modules

Signed-off-by: ericharper <[email protected]>

* convert model to list of modules

Signed-off-by: ericharper <[email protected]>

* update for list of modules

Signed-off-by: ericharper <[email protected]>

* add virtual to init

Signed-off-by: ericharper <[email protected]>

* update first last stage embedding all reduce

Signed-off-by: ericharper <[email protected]>

* update sequence parallel all reduce for virtual models

Signed-off-by: ericharper <[email protected]>

* runs but we get an error

Signed-off-by: ericharper <[email protected]>

* set virtual rank 0 after looping

Signed-off-by: ericharper <[email protected]>

* account for virtual when determinining first and last pipeline stages

Signed-off-by: ericharper <[email protected]>

* checkpointing for virtual models in progress

Signed-off-by: ericharper <[email protected]>

* add checkpoint hooks

Signed-off-by: ericharper <[email protected]>

* working on validation when resuming

Signed-off-by: ericharper <[email protected]>

* skip sanity val steps by default in config

Signed-off-by: ericharper <[email protected]>

* remove comment

Signed-off-by: ericharper <[email protected]>

* log number of params

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* style

Signed-off-by: ericharper <[email protected]>

* check if self.model is a list

Signed-off-by: ericharper <[email protected]>

* make virtual pipeline default size None on init

Signed-off-by: ericharper <[email protected]>

* make virtual pipeline default to None in config

Signed-off-by: ericharper <[email protected]>

* remove ensure_divisibility call

Signed-off-by: ericharper <[email protected]>

* fix lgtm alerts

Signed-off-by: ericharper <[email protected]>

* remove num_sanity_val_steps from config

Signed-off-by: ericharper <[email protected]>

* default virtual pipeline size to none

Signed-off-by: ericharper <[email protected]>

* check for list

Signed-off-by: ericharper <[email protected]>

* update assert to make sure we are only doing virtual for gpt

Signed-off-by: ericharper <[email protected]>

* revert change to get_params_for_weight_decay

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* init var

Signed-off-by: ericharper <[email protected]>

* add import guard for set virtual model parallel world size

Signed-off-by: ericharper <[email protected]>

* use import guard

Signed-off-by: ericharper <[email protected]>

* update calls to fake init in eval scripts

Signed-off-by: ericharper <[email protected]>

* add _get_fwd_bwd_function

Signed-off-by: ericharper <[email protected]>

* log all total model parameters

Signed-off-by: ericharper <[email protected]>

* remove unused import

Signed-off-by: ericharper <[email protected]>

Signed-off-by: ericharper <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Hainan Xu <[email protected]>

* reduced to 14 inactive days to be stale for PRs. (#5165)

Signed-off-by: Xuesong Yang <[email protected]>

Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>

* refactor TTS documentation organization and add new contents. (#5137)

* refactor TTS documentation organization and add new contents.
* fix asr api bug.
* fix broken links.
* fix unexpected indentation errors.
* fixed unexpected indentation.
* fixed broken paper reference.
* fixed cross-reference and typos.
* fixed toctree errors.
* revert to 'Augmentors'
* reordered TTS tutorial list in starthere.
* ordered api classes alphabetically for each Section.
* fixed underscore typo for fastpitch checkpoint.

Signed-off-by: Xuesong Yang <[email protected]>

* upcase 'Tuning'

Signed-off-by: Xuesong Yang <[email protected]>

* fixed typo for RAD-TTS Aligner

Signed-off-by: Xuesong Yang <[email protected]>

* reorder aligner section after mel-gen and vocoders in models.rst.

Signed-off-by: Xuesong Yang <[email protected]>

* clarify Mixer-TTS-X and reorder model descriptions alphabetically.

Signed-off-by: Xuesong Yang <[email protected]>

* fixed some typos and formats.

Signed-off-by: Xuesong Yang <[email protected]>

* removed old megatron.rst.

Signed-off-by: Xuesong Yang <[email protected]>

* fixed block quote ends without a blank line warnings.

Signed-off-by: Xuesong Yang <[email protected]>

* remove duplicate reference; fixed missing key nlp-megatron-shoeybi2019megatron

Signed-off-by: Xuesong Yang <[email protected]>

* Revert "removed old megatron.rst."

This reverts commit c5ea1dc3f23272eecfe8040e3abfa54fa122cf73.

Signed-off-by: Xuesong Yang <[email protected]>

* removed Russian, a hyphen, and add a note about G2P in tts/…
  • Loading branch information
titu1994 authored and zhehuaichen committed Oct 4, 2023
1 parent 354eafc commit e4be603
Show file tree
Hide file tree
Showing 12 changed files with 263 additions and 112 deletions.
26 changes: 21 additions & 5 deletions nemo/collections/asr/losses/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTLossPytorch
from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType
from nemo.core.utils import numba_utils
from nemo.core.utils.k2_utils import K2_INSTALLATION_MESSAGE
from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE
from nemo.utils import logging, model_utils
from nemo.utils import logging, logging_mode, model_utils

try:
import warprnnt_pytorch as warprnnt
Expand Down Expand Up @@ -98,7 +99,7 @@ class RNNTLossConfig:
min_version='0.53.0',
is_available=NUMBA_RNNT_AVAILABLE,
installation_msg=NUMBA_INSTALLATION_MESSAGE,
force_float32=True,
force_float32=not numba_utils.NUMBA_FP16_SUPPORTED,
),
"pytorch": RNNTLossConfig(
loss_name="pytorch",
Expand Down Expand Up @@ -387,7 +388,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
for the standard "blank" symbol. In particular, say V is the number of non-blank tokens in
the vocabulary, then in the case of,
standard RNNT: num_classes = V
multiblank RNNT: num_classes = V + number-big-blanks (since we store big-blanks before
multiblank RNNT: num_classes = V + number-big-blanks (since we store big-blanks before
standard blank, and the standard blank is the last symbol in the vocab)
TDT: num_classes = V. Note, V here does not include any of the "duration outputs".
Expand All @@ -413,6 +414,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
self.reduction = reduction
self._loss = resolve_rnnt_loss(loss_name, blank_idx=self._blank, loss_kwargs=loss_kwargs)
self._force_float32 = RNNT_LOSS_RESOLVER[loss_name].force_float32
self._fp16_compat_checked = False

def reduce(self, losses, target_lengths):

Expand Down Expand Up @@ -442,8 +444,22 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
max_targets_len = target_lengths.max()

# Force cast joint to float32
# TODO: Remove once Numba supports FP16
if self._force_float32 and log_probs.dtype != torch.float32:
if not self._force_float32 and numba_utils.NUMBA_FP16_SUPPORTED:
# Execute the kernel in fp16
pass
elif self._force_float32 and log_probs.dtype != torch.float32:
# Log just once if fp16 tensor was passed and fp16 Numba CUDA loss could not be used.
if log_probs.dtype == torch.float16 and not self._fp16_compat_checked:
_, reason = numba_utils.is_numba_cuda_fp16_supported(return_reason=True)
logging.warning(
f"Provided RNNT Joint tensor is of dtype {log_probs.dtype}, but RNNT loss could not be calculated "
f"in fp16 due to following reason stated below. Loss will be calculated in fp32. \n\n"
f"{reason}",
mode=logging_mode.ONCE,
)
self._fp16_compat_checked = True

# Upcast the activation tensor and compute loss and grads in fp32
logits_orig = log_probs
log_probs = log_probs.float()
del logits_orig # save memory *before* computing the loss
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/asr/losses/rnnt_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def __init__(self, blank, reduction):
self.reduction = reduction

def forward(self, acts, labels, act_lens, label_lens):
# CPU patch for FP16
if not acts.is_cuda and acts.dtype == torch.float16:
acts = acts.float()

acts = torch.log_softmax(acts, -1)

forward_logprob = self.compute_forward_prob(acts, labels, act_lens, label_lens)
losses = -forward_logprob
if self.reduction == 'mean_batch':
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def rnnt_loss_gpu(

# Select GPU index
cuda.select_device(acts.device.index)
gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False)
gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=torch.float32, requires_grad=False)

### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ###
acts, acts_shape = rnnt_helper.flatten_tensor(acts)
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,15 @@ def forward(self, acts, labels, act_lens, label_lens):
_assert_no_grad(label_lens)
certify_inputs(acts, labels, act_lens, label_lens)

# CPU Patch for fp16 - force cast to fp32
if not acts.is_cuda and acts.dtype == torch.float16:
acts = acts.float()

if self.clamp > 0.0:
acts = LogSoftmaxGradModification.apply(acts, self.clamp)

acts = torch.nn.functional.log_softmax(acts, -1)

return self.rnnt(acts, labels, act_lens, label_lens, self.blank, self.fastemit_lambda)


Expand Down
7 changes: 5 additions & 2 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_
loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu
grads = torch.zeros_like(acts) if acts.requires_grad else None
minibatch_size = acts.size(0)
costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype)
costs = torch.zeros(minibatch_size, device=acts.device, dtype=torch.float32)

loss_func(
acts,
Expand Down Expand Up @@ -119,7 +119,6 @@ def forward(
label_lens: Tensor of (batch) containing label length of each example
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
durations: list of durations for TDT model, must include 0 and 1, e.g.
[0, 1, 2, 3, 4].
sigma: hyper-parameter for logit under-normalization method for training
Expand Down Expand Up @@ -417,6 +416,10 @@ def forward(self, acts, labels, act_lens, label_lens):
label_lens: Tensor of (batch) containing label length of each example
"""
if not acts.is_cuda:
# Force FP32 until log_softmax() is implemented for fp16 on CPU
if acts.dtype == torch.float16:
acts = acts.float()

# Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping
# *after* we have obtained the gradients of loss(logsoftmax()).
# This is highly wasteful since it requires a copy of the entire joint tensor which is expensive.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def cost_and_grad_kernel(
)

# Scale llForward by FastEmit lambda
llForward *= 1.0 + self.fastemit_lambda_
llBackward *= 1.0 + self.fastemit_lambda_
llForward += llForward * self.fastemit_lambda_
llBackward += llBackward * self.fastemit_lambda_

diff = (llForward - llBackward).abs()
if diff > 0.1:
Expand Down Expand Up @@ -300,6 +300,10 @@ def compute_betas_and_grads(
Returns:
Loglikelihood of the forward variable and inplace updates the grad tensor.
"""
# Patch for CPU + fp16
if log_probs.dtype == torch.float16 and not log_probs.is_cuda:
log_probs = log_probs.float()

idx = CpuRNNT_index(U, self.maxU_, self.minibatch_, self.alphabet_size_, self.batch_first)
betas[idx(T - 1, U - 1)] = log_probs[idx(T - 1, U - 1) * 2]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import math
from typing import Optional, Tuple

import numba
import torch
from numba import cuda

Expand Down Expand Up @@ -112,7 +113,7 @@ def compute_costs_data(source: torch.Tensor, dest: torch.Tensor, fastemit_lambda
if idx < length:
copy_data_1d(source, dest, idx)
dest[idx] *= -1.0
dest[idx] *= 1.0 + fastemit_lambda
dest[idx] *= numba.float32(1.0 + fastemit_lambda)


def get_workspace_size(
Expand Down
36 changes: 36 additions & 0 deletions nemo/core/utils/numba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import operator
import os

from typing import Tuple, Union

from nemo.utils import model_utils

# Prevent Numba CUDA logs from showing at info level
Expand All @@ -26,6 +28,11 @@
__NUMBA_DEFAULT_MINIMUM_VERSION__ = "0.53.0"
__NUMBA_MINIMUM_VERSION__ = os.environ.get("NEMO_NUMBA_MINVER", __NUMBA_DEFAULT_MINIMUM_VERSION__)

__NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__ = "0.57.0"
NUMBA_FP16_SUPPORTED = model_utils.check_lib_version(
'numba', __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__, operator=operator.ge
)[0]


NUMBA_INSTALLATION_MESSAGE = (
"Could not import `numba`.\n"
Expand Down Expand Up @@ -148,6 +155,35 @@ def numba_cuda_is_supported(min_version: str) -> bool:
return False


def is_numba_cuda_fp16_supported(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
"""
Utility method that returns a bool, stating if FP16 is supported for numba cuda kernels or not.
Returns:
bool, whether Numba CUDA will support fp16 or not.
"""
reason = ""
use_nvidia_binding = os.environ.get('NUMBA_CUDA_USE_NVIDIA_BINDING', None)
if use_nvidia_binding is not None:
use_nvidia_binding = use_nvidia_binding.lower() == "1"
reason += "Env variable `NUMBA_CUDA_USE_NVIDIA_BINDING` is available and set to `1`. "
else:
use_nvidia_binding = False
reason += "Env variable `NUMBA_CUDA_USE_NVIDIA_BINDING` is not available or has not set to `1`."

if NUMBA_FP16_SUPPORTED:
reason += f"Numba CUDA FP16 is supported in installed numba version."
else:
reason += f"Numba CUDA FP16 is not supported in installed numba version."

result = use_nvidia_binding and NUMBA_FP16_SUPPORTED

if return_reason:
return result, reason
else:
return result


def skip_numba_cuda_test_if_unsupported(min_version: str):
"""
Helper method to skip pytest test case if numba cuda is not supported.
Expand Down
Loading

0 comments on commit e4be603

Please sign in to comment.