diff --git a/README.rst b/README.rst index 7ac95b8cef70..ac611e073127 100644 --- a/README.rst +++ b/README.rst @@ -86,9 +86,9 @@ Key Features * Hybrid Transducer/CTC * NeMo Original `Multi-blank Transducers `_ and `Token-and-Duration Transducers (TDT) `_ * Streaming/Buffered ASR (CTC/Transducer) - `Chunked Inference Examples `_ - * Cache-aware Streaming Conformer - ``_ + * Cache-aware Streaming Conformer with multiple lookaheads - ``_ * Beam Search decoding - * `Language Modelling for ASR `_: N-gram LM in fusion with Beam Search decoding, Neural Rescoring with Transformer + * `Language Modelling for ASR (CTC and RNNT) `_: N-gram LM in fusion with Beam Search decoding, Neural Rescoring with Transformer * `Support of long audios for Conformer with memory efficient local attention `_ * `Speech Classification, Speech Command Recognition and Language Identification `_: MatchboxNet (Command Recognition), AmberNet (LangID) * `Voice activity Detection (VAD) `_: MarbleNet @@ -115,11 +115,12 @@ Key Features * `Prompt Learning `_ * `NGC collection of pre-trained NLP models. `_ * `Synthetic Tabular Data Generation `_ -* `Speech synthesis (TTS) `_ - * Spectrogram generation: Tacotron2, GlowTTS, TalkNet, FastPitch, FastSpeech2, Mixer-TTS, Mixer-TTS-X - * Vocoders: WaveGlow, SqueezeWave, UniGlow, MelGAN, HiFiGAN, UnivNet - * End-to-end speech generation: FastPitch_HifiGan_E2E, FastSpeech2_HifiGan_E2E, VITS - * `NGC collection of pre-trained TTS models. `_ +* Text-to-Speech Synthesis (TTS): + * `Documentation `_ + * Mel-Spectrogram generators: FastPitch, SSL FastPitch, Mixer-TTS/Mixer-TTS-X, RAD-TTS, Tacotron2 + * Vocoders: HiFiGAN, UnivNet, WaveGlow + * End-to-End Models: VITS + * `Pre-trained Model Checkpoints in NVIDIA GPU Cloud (NGC) `_ * `Tools `_ * `Text Processing (text normalization and inverse text normalization) `_ * `CTC-Segmentation tool `_ @@ -132,8 +133,8 @@ Built for speed, NeMo can utilize NVIDIA's Tensor Cores and scale out training t Requirements ------------ -1) Python 3.8 or above -2) Pytorch 1.10.0 or above +1) Python 3.9 or above +2) Pytorch 1.13.1 or above 3) NVIDIA GPU for training Documentation diff --git a/docs/source/asr/data/benchmark_en.csv b/docs/source/asr/data/benchmark_en.csv index 684d9f9fa76d..dfd64cc83084 100644 --- a/docs/source/asr/data/benchmark_en.csv +++ b/docs/source/asr/data/benchmark_en.csv @@ -34,4 +34,5 @@ stt_en_fastconformer_ctc_xlarge,EncDecCTCModelBPE,"https://ngc.nvidia.com/catalo stt_en_fastconformer_transducer_xxlarge,EncDecRNNTBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_transducer_xxlarge" stt_en_fastconformer_hybrid_large_streaming_80ms,EncDecHybridRNNTCTCBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_80ms" stt_en_fastconformer_hybrid_large_streaming_480ms,EncDecHybridRNNTCTCBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_480ms" -stt_en_fastconformer_hybrid_large_streaming_1040ms,EncDecHybridRNNTCTCBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_1040ms" \ No newline at end of file +stt_en_fastconformer_hybrid_large_streaming_1040ms,EncDecHybridRNNTCTCBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_1040ms" +stt_en_fastconformer_hybrid_large_streaming_multi,EncDecHybridRNNTCTCBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_multi" \ No 
newline at end of file diff --git a/docs/source/asr/data/scores/en/conformer_en.csv b/docs/source/asr/data/scores/en/conformer_en.csv index d77f0a687ce8..2b31a07b842a 100644 --- a/docs/source/asr/data/scores/en/conformer_en.csv +++ b/docs/source/asr/data/scores/en/conformer_en.csv @@ -18,3 +18,11 @@ stt_en_fastconformer_hybrid_large_streaming_1040ms (CTC),en,,,,,2.7 %,6.4 %,,,9. stt_en_fastconformer_hybrid_large_streaming_80ms (RNNT),en,,,,,2.7 %,6.5 %,,,9.1 %,6.9 %,,,,,,,3.2 %,1.9 % stt_en_fastconformer_hybrid_large_streaming_480ms (RNNT),en,,,,,2.7 %,6.1 %,,,8.5 %,6.7 %,,,,,,,3.1 %,1.8 % stt_en_fastconformer_hybrid_large_streaming_1040ms (RNNT),en,,,,,2.3 %,5.5 %,,,8.0 %,6.6 %,,,,,,,2.9 %,1.6 % +stt_en_fastconformer_hybrid_large_streaming_multi (RNNT - 0ms),en,,,,,,7.0 %,,,,,,,,,,,, +stt_en_fastconformer_hybrid_large_streaming_multi (RNNT - 80ms),en,,,,,,6.4 %,,,,,,,,,,,, +stt_en_fastconformer_hybrid_large_streaming_multi (RNNT - 480ms),en,,,,,,5.7 %,,,,,,,,,,,, +stt_en_fastconformer_hybrid_large_streaming_multi (RNNT - 1040ms),en,,,,,,5.4 %,,,,,,,,,,,, +stt_en_fastconformer_hybrid_large_streaming_multi (CTC - 0ms),en,,,,,,8.4 %,,,,,,,,,,,, +stt_en_fastconformer_hybrid_large_streaming_multi (CTC - 80ms),en,,,,,,7.8 %,,,,,,,,,,,, +stt_en_fastconformer_hybrid_large_streaming_multi (CTC - 480ms),en,,,,,,6.7 %,,,,,,,,,,,, +stt_en_fastconformer_hybrid_large_streaming_multi (CTC - 1040ms),en,,,,,,6.2 %,,,,,,,,,,,, diff --git a/docs/source/asr/models.rst b/docs/source/asr/models.rst index 708d66307dd3..a4f77625dff8 100644 --- a/docs/source/asr/models.rst +++ b/docs/source/asr/models.rst @@ -175,7 +175,7 @@ We support the following three right context modeling: * fully causal model with zero look-ahead: tokens would not see any future tokens. convolution layers are all causal and right tokens are masked for self-attention. It gives zero latency but with limited accuracy. -To train such a model, you need to set `encoder.att_context_size=[left_context, 0]` and `encoder.conv_context_size=causal` in the config. +To train such a model, you need to set `model.encoder.att_context_size=[left_context,0]` and `model.encoder.conv_context_size=causal` in the config. * regular look-ahead: convolutions would be able to see few future frames, and self-attention would also see the same number of future tokens. @@ -186,13 +186,11 @@ For example for a model of 17 layers with 4x downsampling and 10ms window shift, For example, in a model which chunk size of 20 tokens, tokens at the first position of each chunk would see all the next 19 tokens while the last token would see zero future tokens. This approach is more efficient than regular look-ahead in terms of computations as the activations for most of the look-ahead part would be cached and there is close to zero duplications in the calculations. -In terms of accuracy, this approach gives similar or even better results in term of accuracy than regular look-ahead as each token in each layer have access to more tokens on average. That is why we recommend to use this approach for streaming. - +In terms of accuracy, this approach gives similar or even better results than regular look-ahead, as each token in each layer has access to more tokens on average. That is why we recommend the chunk-aware approach for cache-aware streaming models.
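For concreteness, below is a minimal sketch (not part of the original documentation) of how the two causal-context settings mentioned above could be applied to a copy of a training config with OmegaConf before launching training; ``streaming_config.yaml`` is a placeholder name for one of the example cache-aware streaming configs referenced later in this section, not an actual file shipped with NeMo.

.. code-block:: python

    # Sketch only: apply the causal-context settings described above to a local
    # copy of an example cache-aware streaming config. "streaming_config.yaml"
    # is a placeholder file name, not a file shipped with NeMo.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("streaming_config.yaml")

    # Fully causal encoder: zero look-ahead and causal convolutions.
    cfg.model.encoder.att_context_size = [70, 0]   # [left_context, right_context]
    cfg.model.encoder.conv_context_size = "causal"

    # Save the modified config and pass it to the training script.
    OmegaConf.save(cfg, "streaming_config_causal.yaml")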
** Note: Latencies are based on the assumption that the forward time of the network is zero and it just estimates the time needed after a frame would be available until it is passed through the model. -Approaches with non-zero look-ahead can give significantly better accuracy by sacrificing latency. The latency can get controlled by the left context size. Increasing the right context would help the accuracy to a limit but would increase the compuation time. - +Approaches with non-zero look-ahead can give significantly better accuracy by sacrificing latency. The latency can get controlled by the left context size. Increasing the right context would help the accuracy to a limit but would increase the computation time. In all modes, left context can be controlled by the number of tokens to be visible in the self-attention and the kernel size of the convolutions. For example, if left context of self-attention in each layer is set to 20 tokens and there are 10 layers of Conformer, then effective left context is 20*10=200 tokens. @@ -202,23 +200,39 @@ Left context of convolutions is dependent to the their kernel size while it can Self-attention left context of around 6 secs would give close result to have unlimited left context. For a model with 4x downsampling and shift window of 10ms in the preprocessor, each token corresponds to 4*10=40ms. If striding approach is used for downsampling, all the convolutions in downsampling would be fully causal and don't see future tokens. -You may use stacking for downsampling in the streaming models which is significantly faster and uses less memory. -It also does not some of the the limitations with striding and vggnet and you may use any downsampling rate. -You may find the example config files of cache-aware streaming Conformer models at -``/examples/asr/conf/conformer/streaming/conformer_transducer_bpe_streaming.yaml`` for Transducer variant and -at ``/examples/asr/conf/conformer/streaming/conformer_ctc_bpe.yaml`` for CTC variant. +* Multiple Look-aheads +We support multiple look-aheads for cache-aware models. You may specify a list of context sizes for att_context_size. +During the training, different context sizes would be used randomly with the distribution specified by att_context_probs. +For example, you may enable multiple look-aheads by setting `model.encoder.att_context_size=[[70,13],[70,6],[70,1],[70,0]]` for the training. +The first item in the list would be the default during test/validation/inference. To switch between different look-aheads, you may use the method `asr_model.encoder.set_default_att_context_size(att_context_size)` or set the att_context_size like the following when using the script `transcribe_speech.py`: + +.. code-block:: bash + + python [NEMO_GIT_FOLDER]/examples/asr/transcribe_speech.py \ + pretrained_name="stt_en_fastconformer_hybrid_large_streaming_multi" \ + audio_dir="" \ + att_context_size=[70,0] + +.. + +You may find the example config files for cache-aware streaming FastConformer models at +``/examples/asr/conf/fastconformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml`` for Transducer variant and +at ``/examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe.yaml`` for CTC variant. It is recommended to use FastConformer as they are more than 2X faster in both training and inference than regular Conformer.
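As a quick illustration of the ``set_default_att_context_size`` method mentioned above, here is a minimal Python sketch that loads the multi-lookahead checkpoint and switches its look-ahead before transcription; it assumes the standard NeMo ``from_pretrained``/``transcribe`` API, and ``audio.wav`` is a hypothetical file name.

.. code-block:: python

    # Sketch only: switch the default look-ahead of a multi-lookahead
    # cache-aware model at inference time. "audio.wav" is a placeholder
    # for a real audio file path.
    import nemo.collections.asr as nemo_asr

    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name="stt_en_fastconformer_hybrid_large_streaming_multi"
    )

    # Pick the zero look-ahead (lowest-latency) variant from the trained list.
    asr_model.encoder.set_default_att_context_size(att_context_size=[70, 0])

    print(asr_model.transcribe(["audio.wav"]))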
+The hybrid versions of FastConformer can be found here: ``/examples/asr/conf/conformer/hybrid_cache_aware_streaming/`` + +Examples for regular Conformer can be found at +``/examples/asr/conf/conformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml`` for Transducer variant and +at ``/examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe.yaml`` for CTC variant. To simulate cache-aware streaming, you may use the script at ``/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py``. It can simulate streaming in single stream or multi-stream mode (in batches) for an ASR model. This script can be used for models trained offline with full-context but the accuracy would not be great unless the chunk size is large enough which would result in high latency. It is recommended to train a model in streaming model with limited context for this script. More info can be found in the script. -You may find FastConformer variants of cache-aware streaming models under ``/examples/asr/conf/fastconformer/``. - Note cache-aware streaming models are being exported without caching support by default. To include caching support, `model.set_export_config({'cache_support' : 'True'})` should be called before export. Or, if ``/scripts/export.py`` is being used: -`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True` +`python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --export-config cache_support=True` .. _LSTM-Transducer_model: @@ -299,7 +313,7 @@ Similar example configs for FastConformer variants of Hybrid models can be found Note Hybrid models are being exported as RNNT (encoder and decoder+joint parts) by default. To export as CTC (single encoder+decoder graph), `model.set_export_config({'decoder_type' : 'ctc'})` should be called before export. Or, if ``/scripts/export.py`` is being used: -`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc` +`python export.py hybrid_transducer.nemo hybrid_transducer.onnx --export-config decoder_type=ctc` .. _Conformer-HAT_model: diff --git a/docs/source/core/export.rst b/docs/source/core/export.rst index f54daffe9c9c..202099b13d66 100644 --- a/docs/source/core/export.rst +++ b/docs/source/core/export.rst @@ -207,7 +207,7 @@ An example can be found in ``/nemo/collections/asr/models/rnnt_mo Here is example on now `set_export_config()` call is being tied to command line arguments in ``/scripts/export.py`` : .. code-block:: Python - python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc + python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --export-config decoder_type=ctc Exportable Model Code ~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/nlp/api.rst b/docs/source/nlp/api.rst index 0822ade0224c..b13dedca300f 100755 --- a/docs/source/nlp/api.rst +++ b/docs/source/nlp/api.rst @@ -124,7 +124,7 @@ Datasets .. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset :show-inheritance: -.. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTChatDataset +.. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset.GPTSFTChatDataset :show-inheritance: .. 
autoclass:: nemo.collections.nlp.data.language_modeling.megatron.retro_dataset.RETRODataset diff --git a/docs/source/nlp/nemo_megatron/retro/retro_model.rst b/docs/source/nlp/nemo_megatron/retro/retro_model.rst index edbec3d1c2ca..ceff1baf857f 100644 --- a/docs/source/nlp/nemo_megatron/retro/retro_model.rst +++ b/docs/source/nlp/nemo_megatron/retro/retro_model.rst @@ -1,2 +1,444 @@ -Coming Soon ... -================ \ No newline at end of file +NeMo RETRO Model +================ + +The Retrieval-Enhanced Transformer (RETRO) model is an autoregressive language model that takes into account document chunks retrieved from a large +corpus when making predictions. The RETRO model has a similar architecture to the GPT model, but it includes an encoder that encodes the retrieved +context and cross-attention layers that integrate the context to improve the model's output. Below is a simple diagram of the RETRO model architecture. + +.. image:: images/arch.png + :align: center + :width: 800px + :alt: RETRO model architecture + +For more detailed information on the model, please refer to the `RETRO paper `_ :cite:`nlp-retro-borgeaud2021improving` by DeepMind. +The NeMo RETRO Model is an open-source implementation of the paper, and it has the following differences/features compared to DeepMind's proposed implementation: + +1. The NeMo RETRO Model is built on top of NeMo Megatron code, allowing for efficient training of large language models in a cluster environment. +2. The NeMo RETRO Model uses `Faiss `_ :cite:`nlp-retro-jegou2022faiss` as the KNN search library, which can be accelerated by GPUs. +3. The NeMo RETRO uses `RoPE relative positional encoding `_ :cite:`nlp-retro-su2021roformer`. +4. The NeMo RETRO uses `SentenceTransformers `_ :cite:`nlp-retro-reimers2019sentence` as the retriever encoder. +5. The NeMo RETRO supports `mu-Transfer `_ :cite:`nlp-retro-yang2022tensor`, allowing for scalable training of the RETRO model via Zero-Shot Hyperparameter Transfer. + +Quick start +************ +Steps below demonstrate training and evaluating a NeMo RETRO model. + +Data pre-processing +------------------- + +Step 1: Collect training data +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The RETRO model uses two types of data: training data, which typically consists of 64-token chunks, and retrieval data, which typically consists of 128-token chunks. +The training data is used to train the model, while the retrieval data is used to supplement the language model. +It's possible to use the same data for both training and retrieval, as long as duplicates are removed properly, as described below. +Both types of data are stored in a loose JSON format, with each line containing a single text sample. For example: + +.. code-block:: json + {"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"} + {"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"} +The name of the text field of the json can be changed by using the ``--json-key`` flag in ``preprocess_data_for_megatron.py``. The other metadata are optional and are not used in training. + +Step 2: Convert training data into memory map format +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The loose json is then processed into a binary format for training and retrieval. To convert the json into mmap/cached index files, +set the ``--dataset-impl`` flag to `retmmap`, which is the memory map format dedicated to the RETRO model. 
+ +An example script to prepare data for RETRO training is: + +.. code-block:: bash + python scripts/nlp_language_modeling/preprocess_data_for_megatron.py \ + --input=/dataset/pubmed_train.jsonl \ + --json-keys=text \ + --tokenizer-library=megatron \ + --apply-ftfy \ + --dataset-impl=retmmap \ + --merge-file=/dataset/gpt2-merges.txt \ + --vocab-file=/dataset/gpt2-vocab.json \ + --tokenizer-type=GPT2BPETokenizer \ + --output-prefix=/result/pubmed_train \ + --need-pad-id \ + --append-eod \ + --retrieval-db \ + --chunk_size=64 \ + --workers=48 +The RETRO model processes chunked documents using 64 tokens as the default chunk size. The RETRO memory map dataset will add padding +tokens to the end of each document to make it a multiple of 64. The ``--need-pad-id`` argument adds a padding token to the tokenizer +if it doesn't already have one. The ``--append-eod`` argument controls whether to add ``end-of-document`` tokens to the preprocessed +data, and the ``--retrieval-db`` argument indicates whether to create a retrieval database for the preprocessed data. If ``--retrieval-db`` +is used, it will add an additional 64 padding tokens at the end of the document. The ``--chunk_size`` and ``--workers`` arguments +control the size of the data chunks to be processed and the number of worker processes to use, respectively. + +Following is the retro memory map index data format: + +.. list-table:: + :widths: 25 25 25 25 25 25 + + * - 'MMIDRET\x00\x00' (header 9 bytes) + - 1 (version 8 byte) + - dtype code :sup:`1` (1 byte) + - sentence count (8 byte) + - chunk size (8 byte) + - chunk count (8 byte) + * - retrieved db :sup:`2` (1 byte) + - number of tokens for each of sentences ( int32 array) + - start of sentence address in byte (int64 array) + - start of chunk id (int64 array) + - chunk id address in byte (int64 array) + - + +:sup:`1` 1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16 + +:sup:`2` When building the indexed dataset, we pad each sentence to be a multiple of ``chunk_size`` with ``pad_id`` from the tokenizer. +The number of tokens for each sentence includes the padded token ids. For retrieval data, there is an extra ``chunk_size`` padding at +the end of each sentence, and the ``retrieved_db`` flag is set to True. However, the number of tokens for each sentence excludes this extra ``chunk_size`` padding. + +Following is the retro memory map binary data format: + +.. list-table:: + :widths: 65 + + * - token id array for sentence 0,1, 2 ... (dtype :sup:`3` array) + +:sup:`3` np.uint16 vocab_size < 65500 else np.int32 + +Step 3: Create Faiss index for retrieval data +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +After creating the memory map retrieval data binary file and index files, we can build a Faiss index that can quickly find the K-nearest neighbors of a given +chunk ID based on a query embedding vector. Because the retrieval data is typically very large, we break this process down into three steps. + +Step 3.1: Train the Faiss index structure +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In this step, we use a subset of the retrieval data to train an empty Faiss index. An example script is: + +.. 
code-block:: bash + python scripts/nlp_language_modeling/build_retrieval_index.py \ + --input_file=/result/pubmed_train_text_document \ + --tokenizer-library=megatron \ + --tokenizer-type=GPT2BPETokenizer \ + --merge-file=/dataset/gpt2-merges.txt \ + --vocab-file=/dataset/gpt2-vocab.json \ + --percent=1.0 \ + --sentence_transformer_model=all-mpnet-base-v2 \ + --batch_size=1024 \ + --train_index_size=2000000 \ + --workers=2 \ + --devices=0,1,2,3,4,5,6,7 \ + --stage=0 \ + --output_file=/result/pubmed_faiss_learn.index +This command is used to build an empty Faiss index using the 2000000 training data in ``pubmed_train_text_document``. +The ``all-mpnet-base-v2`` sentence transformer model is used to encode the chunk tokens into an embedding vector. +The index will be saved in the result directory as ``pubmed_faiss_learn.index``. This command specifies using 8 GPUs to train the Faiss index. + +Step 3.2: Add retrieval data into sharding index +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This step adds all the retrieval data to the empty Faiss index created in the previous step. An example script is: + +.. code-block:: bash + python scripts/nlp_language_modeling/build_retrieval_index.py \ + --input_file=/result/pubmed_train_text_document \ + --tokenizer-library=megatron \ + --tokenizer-type=GPT2BPETokenizer \ + --merge-file=/dataset/gpt2-merges.txt \ + --vocab-file=/dataset/gpt2-vocab.json \ + --percent=1.0 \ + --sentence_transformer_model=all-mpnet-base-v2 \ + --batch_size=1024 \ + --shard_id=0 \ + --total_shards=10 \ + --workers=2 \ + --devices=0,1,2,3,4,5,6,7 \ + --stage=1 \ + --learned_index=/result/pubmed_faiss_learn.index \ + --output_file=/result/pubmed_faiss_shard0.save +This command breaks the retrieval data into ``total_shards`` shards and adds the data in the shard specified by ``shard_id``. +The result is saved to a file specified by ``output_file``. In the example above, 10 sharding indexes are created. + +Step 3.3: Merge the sharding indexes into final Faiss index +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This step merges all the sharding indexes created in the previous step into the final Faiss index. An example script is: + +.. code-block:: bash + python scripts/nlp_language_modeling/build_retrieval_index.py \ + --stage=2 \ + --devices=0,1,2,3,4,5,6,7 \ + --learned_index=/result/pubmed_faiss_learn.index \ + --shard_index_input=/result/pubmed_faiss_shard \ + --output_file=/result/pubmed_faiss_final.index +Step 4: Build KNN index +^^^^^^^^^^^^^^^^^^^^^^^ + +During training, it is inefficient to run a query to find the K-nearest neighbor chunk IDs for each training data point. +This can be pre-calculated by building a KNN index before training. The KNN index maps the training data chunk IDs to the K-nearest neighbor chunk IDs +in the retrieval data. As with building the Faiss index, this process is divided into two steps. + +Following is the KNN index data format: + +.. list-table:: + :widths: 25 25 25 25 45 + + * - 'KNNRETM\x00\x00' (header 9 bytes) + - 1 (version 8 byte) + - K number of neighbors (8 byte) + - Number chunks (8 byte) + - Map to K retrieval data chunk IDs, shape (number_chunks, K) ( int64 array) + +Step 4.1: Build KNN sharding index +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The KNN index is built using the memory-mapped training data created by the ``preprocess_data_for_megatron.py`` script and the Faiss index +file for the retrieval data built by the ``build_retrieval_index.py`` script. + +An example script is: + +.. 
code-block:: bash + python scripts/nlp_language_modeling/build_knn_map_index.py \ + --input_file=/result/pubmed_eval_text_document \ + --tokenizer-library=megatron \ + --tokenizer-type=GPT2BPETokenizer \ + --merge-file=/dataset/gpt2-merges.txt \ + --vocab-file=/dataset/gpt2-vocab.json \ + --process_chunk_size=10000 \ + --sentence_transformer_model=all-mpnet-base-v2 \ + --batch_size=1024 \ + --K_neighbors=50 \ + --workers=2 \ + --devices=0,1,2,3,4,5,6,7 \ + --remove_duplicate \ + --dedup_margin=70 \ + --nprobe=100 \ + --shard_id=0 \ + --total_shards=10 \ + --stage=1 \ + --output_file=/dataset/pubmed_knn_shard0.save \ + --faiss_index=/result/pubmed_faiss_final.index +In this example, the training data is broken into ``total_shards`` shards, and the KNN index is calculated for the shard specified by ``shard_id``. +The result is saved to a file specified by ``output_file``. In the example above, 10 KNN sharding indexes are created. + +Use the ``remove_duplicate`` flag if the training data and retrieval data are the same to remove neighbors from the same document. + +Step 4.2: Merge KNN sharding index into final KNN index +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +An example script is: + +.. code-block:: bash + python scripts/nlp_language_modeling/build_knn_map_index.py \ + --stage=2 \ + --output_file=pubmed_knn_final.save \ + --shard_index_input=pubmed_knn_shard +Train NeMo RETRO Model +----------------------- + +Once the training data, retrieval data, KNN index, and Faiss index are prepared, we are ready to train the RETRO model. In the NeMo implementation, +the RETRO model can be pre-trained with or without the `mu-Transfer `_ :cite:`nlp-retro-yang2022tensor` feature. We will introduce both ways. + + +The table below lists some of the common parameters that can be configured for model pre-training. 
+ ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| **Parameter** | **Default** | **Description** | ++==================================+=============+========================================================================================+ +| model.micro_batch_size | 4 | the micro batch size used for training | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.tensor_model_parallel_size | 1 | tensor model parallel size | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.encoder_seq_length | 2048 | token sequence length | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.chunk_size | 64 | the chunk size used to retrieve | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.enc_num_layers | 4 | total number of encoder layers | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.dec_num_layers | 6 | total number of decoder layers | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.enc_cross_attention | [3] | layer numbers for cross attention in encoder | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.dec_cross_attention | [3,4,5] | layer numbers for chunked cross attention in decoder | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.add_position_embedding | FALSE | whether to add the absolute position encoding | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.hidden_size | 768 | model hidden size | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.ffn_hidden_size | 3072 | model FFN hidden size. 
Usually 4 * hidden_size | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.num_attention_heads | 12 | number of attention heads | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.init_method_std | 0.02 | standard deviation of the zero mean normal distribution used for weight initialization | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.hidden_dropout | 0.1 | dropout probability for hidden state transformer | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.attention_dropout | 0.1 | dropout probability in the attention layer | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ +| model.ffn_dropout | 0 | dropout probability in the feed-forward layer | ++----------------------------------+-------------+----------------------------------------------------------------------------------------+ + + +Option 1: Train the NeMo RETRO model *without* mu-Transfer +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +An example RETRO pre-training script is: + +.. code-block:: bash + python examples/nlp/language_modeling/megatron_retro_pretraining.py \ + trainer.devices=8 \ + trainer.num_nodes=2 \ + trainer.accelerator=gpu \ + trainer.max_steps=800000 \ + trainer.precision=16 \ + exp_manager.exp_dir=/result/retro_model \ + model.apply_query_key_layer_scaling=False \ + model.tensor_model_parallel_size=8 \ + model.optim.name=adamw \ + model.enc_num_layers=2 \ + model.dec_num_layers=32 \ + model.enc_cross_attention=[0] \ + model.dec_cross_attention=[8,11,14,17,20,23,26,29,31] \ + model.hidden_size=4096 \ + model.ffn_hidden_size=16384 \ + model.num_attention_heads=32 \ + model.tokenizer.merge_file=/dataset/gpt2-merges.txt \ + model.tokenizer.vocab_file=/dataset/gpt2-vocab.json \ + model.data.data_prefix=[/result/pubmed_eval_text_document] \ + model.data.knn_index=[dataset/pubmed_knn_final.save] \ + model.data.retrieval_prefix=/result/pubmed_eval_text_document \ + model.micro_batch_size=8 +During the training, launch Tensorboard to monitor training like so: + +.. code-block:: bash + tensorboard --logdir /result/retro_model --bind_all +.. note:: Weights and Biases (WandB) is supported too. Add ``exp_manager.create_wandb_logger=True`` to the model training arguments to enable it. + +After the training, the model nemo file can be found in the result checkpoint directory. + +Option 2: Train the NeMo RETRO model *with* mu-Transfer +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The `mu-Transfer `_ :cite:`nlp-retro-yang2022tensor` paper proposed a method to zero-shot transfer hyperparameters to train a larger model. +This can be done in 3 steps in the NeMo RETRO implementation. + + +Step 1. Find optimal hyperparameters for a small base model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use the pre-training code in Option 1 to, either manually or automatically, find a set of optimal hyperparameters for a small base RETRO +model. This can be done cheaply and fast due to the small model size. + +Step 2. 
calculate the shape file that can be used to run mu-Transfer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The shape file determines which hyperparameters will be scaled up, allowing the model to adjust the learning rate, weight scaling factor, etc. + +Here is an example shape file calculation script: + + +.. code-block:: bash + python examples/nlp/language_modeling/megatron_retro_cal_shape.py \ + trainer.devices=8 \ + trainer.num_nodes=1 \ + trainer.accelerator=gpu \ + exp_manager.exp_dir=/result/retro_model \ + base_model.enc_num_layers=2 \ + delta_model.enc_num_layers=2 \ + base_model.dec_num_layers=32 \ + delta_model.dec_num_layers=32 \ + base_model.tensor_model_parallel_size=8 \ + delta_model.tensor_model_parallel_size=8 \ + base_model.dec_cross_attention=[8,11,14,17,20,23,26,29,31] \ + delta_model.dec_cross_attention=[8,11,14,17,20,23,26,29,31] \ + base_model.enc_cross_attention=[0] \ + delta_model.enc_cross_attention=[0] \ + base_model.hidden_size=768 \ + base_model.ffn_hidden_size=3072 \ + delta_model.hidden_size=96 \ + delta_model.ffn_hidden_size=384 \ + base_model.num_attention_heads=16 \ + delta_model.num_attention_heads=16 \ + model.shape_file=tp8_32depth_o1_rel_shape_info.yaml +In this example, the ``base_model`` refers to the small base model for which an optimal set of hyperparameters has been determined. +The ``delta_model`` refers to a model with certain hyperparameters that have been scaled up or down. In this case, +the ``hidden_size`` and ``ffn_hidden_size`` have been changed in the ``delta_model``, allowing these two parameters to be scaled freely later. + +Step 3. Pretrain mu-Transfer RETRO model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Once the shape file is created, we can start training a RETRO model. The model training can be scaled up freely using the hyperparameters +specified by the delta model and the shape file. + +An example mu-Transfer pre-training script is: + +.. code-block:: bash + python examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py \ + trainer.devices=8 \ + trainer.num_nodes=2 \ + trainer.accelerator=gpu \ + trainer.max_steps=500000 \ + trainer.precision=16 \ + exp_manager.exp_dir=/result/retro_model \ + model.apply_query_key_layer_scaling=False \ + model.tensor_model_parallel_size=8 \ + model.optim.name=muadamw \ + model.enc_num_layers=2 \ + model.dec_num_layers=32 \ + model.enc_cross_attention=[0] \ + model.dec_cross_attention=[8,11,14,17,20,23,26,29,31] \ + model.hidden_size=4096 \ + model.ffn_hidden_size=16384 \ + model.num_attention_heads=32 \ + model.tokenizer.merge_file=/dataset/gpt2-merges.txt \ + model.tokenizer.vocab_file=/dataset/gpt2-vocab.json \ + model.data.data_prefix=[/result/pubmed_eval_text_document] \ + model.data.knn_index=[dataset/pubmed_knn_final.save] \ + model.data.retrieval_prefix=/result/pubmed_eval_text_document \ + model.micro_batch_size=8 \ + model.shape_file=tp8_32depth_o1_rel_shape_info.yaml +.. note:: We have chosen to use ``muadamw`` as the optimizer for use with the mu-transfer method. Currently, only ``muadam`` and ``muadamw`` are supported. + +As with the pre-training in Option 1, the model nemo file can be found in the result checkpoint directory after training is complete. + +Run NeMo RETRO Model Inference +------------------------------- + +Once the NeMo RETRO model has been trained, we can put it into inference mode and experiment with it. +During inference, we are not limited to the static Faiss index that we built earlier for KNN queries. 
+We can feed any external data to the model as retrieval context. NeMo RETRO implementation supports dynamic retrieval service, +allowing users to add, reset, and query new documents on the fly. + +We have built a simple web client that makes it easy for users to play around with the model. Here is an example script to launch the server: + +.. code-block:: bash + python examples/nlp/language_modeling/megatron_retro_eval.py \ + trainer.devices=8 \ + trainer.num_nodes=1 \ + trainer.accelerator=gpu \ + trainer.precision=16 \ + retro_model_file=megatron_retro.nemo \ + tensor_model_parallel_size=8 \ + pipeline_model_parallel_size=1 \ + retrieval_service.sentence_bert.devices=\'0,1,2,3,4,5,6,7\' \ + retrieval_service.services.0.faiss_devices=\'0,1,2,3,4,5,6,7\' \ + retrieval_service.services.1.faiss_devices=\'0,1,2,3,4,5,6,7\' \ + retrieval_service.services.0.faiss_index=/result/pubmed_faiss_final.index \ + retrieval_service.services.0.retrieval_index=/result/pubmed_eval_text_document \ + retrieval_service.neighbors=2 \ + retrieval_service.pad_tokens=True \ + retrieval_service.store_retrieved=True \ + server=True \ + web_server=True \ + share=True \ + username=test \ + password=test123 +Set the retro_model_file to use the nemo file generated in the pre-training step. After launching the server, copy-paste the URL from +the terminal into your browser. Use the specified username and password to log in and have fun experimenting with the RETRO model. + +References +************ + +.. bibliography:: ../../nlp_all.bib + :style: plain + :labelprefix: nlp-retro + :keyprefix: nlp-retro- diff --git a/docs/source/starthere/intro.rst b/docs/source/starthere/intro.rst index 2e0e272c93f4..70426d3fe4a0 100644 --- a/docs/source/starthere/intro.rst +++ b/docs/source/starthere/intro.rst @@ -34,9 +34,9 @@ Prerequisites Before you begin using NeMo, it's assumed you meet the following prerequisites. -#. You have Python version 3.6, 3.7 or 3.8. +#. You have Python version 3.9, 3.10. -#. You have Pytorch version 1.8.1. +#. You have Pytorch version 1.13.1 or 2.0+. #. You have access to an NVIDIA GPU for training. diff --git a/docs/source/starthere/tutorials.rst b/docs/source/starthere/tutorials.rst index e24637718690..3a0998197732 100644 --- a/docs/source/starthere/tutorials.rst +++ b/docs/source/starthere/tutorials.rst @@ -106,6 +106,12 @@ To run a tutorial: * - ASR - Multi-lingual ASR - `Multi-lingual ASR `_ + * - ASR + - Hybrid ASR-TTS Models + - `Hybrid ASR-TTS Models `_ + * - ASR + - ASR Confidence Estimation + - `ASR Confidence Estimation `_ * - ASR - Confidence-based Ensembles - `Confidence-based Ensembles `_ diff --git a/docs/source/tools/speech_data_processor.rst b/docs/source/tools/speech_data_processor.rst index 29bc4abb82bd..262b214c6355 100644 --- a/docs/source/tools/speech_data_processor.rst +++ b/docs/source/tools/speech_data_processor.rst @@ -1,166 +1,10 @@ Speech Data Processor -======================== +===================== Speech Data Processor (SDP) is a toolkit to make it easy to: 1. write code to process a new dataset, minimizing the amount of boilerplate code required. 2. share the steps for processing a speech dataset. -SDP is hosted here: https://github.com/NVIDIA/NeMo-speech-data-processor. +SDP is hosted here: https://github.com/NVIDIA/NeMo-speech-data-processor. 
-SDP's philosophy is to represent processing operations as 'processor' classes, which take in a path to a NeMo-style data manifest as input (or a path to the raw data directory if you do not have a NeMo-style manifest to start with), apply some processing to it, and then save the output manifest file. - -You specifiy which processors you want to run using a YAML config file. Many common processing operations are provided, and it is easy to add your own. If you do not need to add your own processors, then all that is needed to process a new dataset is to write a single YAML file containing the parameters needed to process your dataset. - -.. image:: https://github.com/NVIDIA/NeMo/releases/download/v1.17.0/sdp_overview_diagram.png - :alt: Overview diagram of Speech Data Processor - -Overview of how SDP processes a dataset ---------------------------------------- - -1. You call the ``main.py`` script, passing in a YAML config file, possibly with some overrides. -2. ``main.py`` script calls ``run_processors.py``, passing in your config. -3. ``run_processors.py`` does the following: - - a. picks out the processors that you specified to be run (you can specify a subset of the processors in the config override, e.g. to avoid re-running time-consuming steps). - b. if some of the processors have not had "output_manifest_file" or "input_manfiest_file" entries specified, SDP will automatically create temporary files for those. - c. instantiates the processor classes using ``hydra.utils.instantiate`` - d. runs the run-time processor tests by calling the ``processor.test()`` method (more details about testing :ref:`here`). - e. runs the processing method (``processor.process()``) of each processor in order. - - -Layout of config YAML files -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The YAML config file for processing a dataset must contain a key ``processors``, the value of which is a list. Each item in that list is expected to be a dictionary specifying a processor class, i.e. it must have a key ``_target_``, the value of which is a path to a "processor" class, and the remaining keys must be the kwargs necessary to instantiate that class with ``hydra.utils.instantiate()`` (c.f. https://hydra.cc/docs/advanced/instantiate_objects/overview/). - -SDP will run the processors specified in the ``processors`` list in the config file. It will also check for a ``processors_to_run`` key in the config file, which can be either the string ``"all"``, or any Python "slice" object like ``3:4``, ``2:`` etc. (if there is no ``processors_to_run`` key, then all of the processors will be run). - -.. note:: - SDP will run the processors in the order in which they are listed in the config YAML file. Make sure to list the processors in an order which makes sense, e.g. create an initial manifest first; make sure to run asr inference before doing any processing which looks at ``pred_text`` fields in the manifest. - -Processor classes ------------------ - -**BaseProcessor** -~~~~~~~~~~~~~~~~~ - -All processor classes inherit from the ``BaseProcessor`` class. This is a simple abstract class which has 2 empty methods: ``process()`` and ``test()``. -These serve to remind us that SDP essentially just runs ``test()`` on all processors, and then ``process()`` on all processors (more details about testing :ref:`here`). - -``ASRInference`` is a child class of ``BaseProcessor``. It has a simple ``process()`` method which runs transcription on every utterance in the input_manifest. - -``WriteManifest`` is also a child class of ``BaseProcessor``. 
It has a simple ``process()`` method which saves a copy of the input manifest containing only the fields specified in ``fields_to_save``. - -**BaseParallelProcessor** -~~~~~~~~~~~~~~~~~~~~~~~~~ -``BaseParallelProcessor`` inherits from the ``BaseProcessor`` class. Within the ``BaseParallelProcessor.process()`` method, it calls other methods and functions, which allow it to do more complex processing. -Most importantly, it calls its ``BaseParallelProcessor.process_dataset_entry(data_entry)`` method on every utterance in the manifest, and it does this in parallel, allowing for more efficient processing. - -What is a **DataEntry**? -~~~~~~~~~~~~~~~~~~~~~~~~ -As mentioned above, ``BaseParallelProcessor.process_dataset_entry(data_entry)`` is called on a variable called ``data_entry`` which represents an utterance in our dataset. -Most often, ``data_entry`` will be a dictionary containing items which represent the JSON manifest entry. -Sometimes, such as in ``CreateInitialManifestMLS``, it will be a string containing a line for that utterance from the original raw MLS transcript. - -``BaseParallelProcessor.process_dataset_entry`` will process ``data_entry`` and output a ``DataEntry`` object. - -The ``DataEntry`` class is a dataclass which contains 2 attributes: - -1. ``data`` is an Optional dictionary containing items which represent the JSON manifest entry. ``data`` can also be ``None``. If a ``.process_dataset_entry(data_entry)`` method returns a ``DataEntry`` class where ``data is None``, then that utterance will be dropped from the output manifest. -2. ``metrics``, which can be of any type, and are ``None`` by default. This variable is used by some variables to record summary statistics about the changes made to the dataset, these metrics are aggregated and can be displayed once every utterance has been processed by the processor. - -What happens in **BaseParallelProcessor.process()**? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -We outline the ``BaseParallelProcessor.process()`` method below: - -.. raw:: html - -
- -
- - -**ModifyManifestTextProcessor** -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -``ModifyManifestTextProcessor`` inherits from the ``BaseParallelProcessor`` class. - -The ``ModifyManifestTextProcessor`` constructor takes in the following arguments: -* ``text_key`` (string) and ``pred_text_key`` (string): these parameters specify which keys in ``data_entry.data`` will be used for processing. (default: ``text_key="text"``, ``pred_text_key="pred_text"``, ie. by default the processor will refer to and modify the ``"text"`` and/or ``"pred_text"`` attributes of the input manifest). -* ``test_cases`` (optional, list of dicts) - test cases for checking that the processor makes the changes that we are expecting. - -``ModifyManifestTextProcessor`` has the following methods: -* ``ModifyManifestTextProcessor.test()``: this method makes sure that the output from the processor matches the expected output specified in the ``test_cases`` parameter. -* ``ModifyManifestTextProcessor.process_dataset_entry(data_entry)``: this method applies processing to a ``data_entry``. First, spaces are added to the start and end of the 'text' and 'pred_text' entries (if they exist), then the abstract method ``ModifyManifestTextProcessor._process_dataset_entry(data_entry)`` is called. Then, any extra spaces (e.g. two spaces next to each other ' ') are removed from 'text' and 'pred_text' entries. -* ``ModifyManifestTextProcessor._process_dataset_entry(data_entry)``: this is an abstract method which will be over-written by children of ``ModifyManifestTextProcessor``. - -How to make your own processor classes --------------------------------------- - -We will describe how to make your own processor classes by referring to SDP's existing classes. - -Creating an initial manifest -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -One of the child classes of ``BaseParallelProcessor`` provided in SDP is ``CreateInitialManifestMLS``. It downloads raw MLS data for a specified language, and creates an initial manifest (in the format expected by NeMo) which can be cleaned by subsequent processors. - -The ``CreateInitialManifestMLS.prepare()`` method downloads and extracts the raw data. - -The ``CreateInitialManifestMLS.read_manifest()`` method reads the lines in the raw MLS transcript file. - -The ``CreateInitialManifestMLS.process_dataset_entry()`` method takes in the lines from the raw MLS transcript file, and outputs ``DataEntry`` objects containing entries that will be saved into the manifest (i.e. ``"audio_filepath"``, ``"duration"``, ``"text"``) for each utterance. - - -A **ModifyManifestTextProcessor** subclass that cleans the reference text -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -One of the classes provided in SDP is ``SubRegex``. At initialization, it takes in ``regex_params_list``, a list of dictionaries which must contain the keys ``"pattern"``, ``"repl"``, and, optionally, ``"count"``. These keys will be used to apply regex substitutions using these parameters fed into ``re.sub``. The substitutions will be applied to the data at ``text_key`` (i.e. ``data_entry.data[self.text_key]``). By default, ``text_key="text"``, i.e. the substitutions will be applied to the ``"text"`` attribute of the manifest. - -In its ``_process_dataset_entry(data_entry)`` method, the ``SubRegex`` processor does the string to string conversion upon the ``data_entry`` that is input. Its output is a ``data_entry`` with the changes applied to ``data``, and the the metrics of which regex patterns caused a substitution to be made. 
These metrics will be aggregated over all utterances by the ``BaseParallelProcessor`` class. ``SubRegex`` also has a ``finalize(metrics)`` method which will log information about the aggregated metrics after all of the utterances in the manifest have been processed. - -A **ModifyManifestTextProcessor** subclass that drops incorrectly transcribed utterances -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -One of the classes provided in SDP is ``DropHighLowCharrate``. At initialization, it takes in ``high_charrate_threshold`` and ``low_charrate_threshold``, for which the utterance will be dropped if it is above or below each value respectively. This is helpful for automatically filtering out incorrectly transcribed utterances. - -In its ``_process_dataset_entry(data_entry)`` method it evaluates the character rate of the utterance(by dividing the length of ``data_entry.data[self.text_key]`` by the value of ``data_entry.data["duration"]``). If the character rate is within bounds, it will return the same ``data_entry`` that was input. If the character rate is out of bounds, it will return a ``data_entry`` with ``data=None`` and ``metrics`` which reflect the applied changes. -Similar to the ``SubSubstringToSpace`` class, it has a ``finalize(metrics)`` method which will log information about the aggregated metrics after all of the utterances in the manifest have been processed. - -Class diagram -------------- -A diagram of the classes mentioned above is included here. Arrows represent inheritance. - -We omit the details of the ``CreateInitialManifestMLS`` class in the diagram in order to save space. - - -.. raw:: html - -
- -
- -SDP Tests --------- -It is important to make sure that your data processing code has the effect you intend, so SDP has a few different types of tests: - -1. Runtime tests - -* Before running the specified processors, SDP runs ``processor.test()`` on all specified processors. -* Currently, the only provided processor classes with a test method are subclasses of ``ModifyManifestTextProcessor``. - - * ``ModifyManifestTextProcessor.test()`` runs any ``test_cases`` that were provided in the object constructor. - * This means you can provided test cases in the YAML config file, and the dataset will only be processed if the test cases pass. - * This is helpful to (a) make sure that the rules you wrote have the effect you desired, and (b) demonstrate why you wrote those rules. - * An example of test cases we could include in the YAML config file:: - - - _target_: sdp.processors.DropIfRegexMatch - regex_patterns: - - "(\\D ){5,20}" # looks for between 4 and 19 characters surrounded by spaces - test_cases: - - {input: {text: "some s p a c e d out letters"}, output: null} - - {input: {text: "normal words only"}, output: {text: "normal words only"}} - -2. ``pytest`` tests which can be run locally with ``python -m pytest tests/`` and will be run during the GitHub CI process. There are 2 sub-types: - - a. "End to end" tests (link) which run SDP on a mini version of the raw initial dataset, and make sure the final manifest matches the reference final manifest. - b. "Unit tests" for processors and utils (link). +To learn more about SDP, please check the `documentation <https://nvidia.github.io/NeMo-speech-data-processor/>`_. diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py index 75912f1c03c1..7726c2b2740e 100644 --- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py +++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py @@ -46,10 +46,13 @@ It may result in slightly different outputs from the sub-sampling module compared to offline mode for some techniques like striding and sw_striding. Enabling it would make it easier to export the model to ONNX. -# Hybrid ASR models +## Hybrid ASR models For Hybrid ASR models which have two decoders, you may select the decoder by --set_decoder DECODER_TYPE, where DECODER_TYPE can be "ctc" or "rnnt". If decoder is not set, then the default decoder would be used which is the RNNT decoder for Hybrid ASR models. +## Multi-lookahead models +For models which support multiple lookaheads, the default is the first one in the list of model.encoder.att_context_size. To change it, you may use --att_context_size, for example --att_context_size [70,1]. + ## Evaluate a model trained with full context for offline mode @@ -58,7 +61,7 @@ The accuracy of the model on the borders of chunks would not be very good. To use a model trained with full context, you need to pass the chunk_size and shift_size arguments. -If shift_size is not passed, chunk_size would be use as the shift_size too. +If shift_size is not passed, chunk_size would be used as the shift_size too. Also argument online_normalization should be enabled to simulate a realistic streaming. The following command would simulate cache-aware streaming on a pretrained model from NGC with chunk_size of 100, shift_size of 50 and 2 left chunks as left context. 
The chunk_size of 100 would be 100*4*10=4000ms for a model with 4x downsampling and 10ms shift in feature extraction. @@ -273,6 +276,13 @@ def main(): help="Selects the decoder for Hybrid ASR models which has both the CTC and RNNT decoder. Supported decoders are ['ctc', 'rnnt']", ) + parser.add_argument( + "--att_context_size", + type=str, + default=None, + help="Sets the att_context_size for the models which support multiple lookaheads", + ) + args = parser.parse_args() if (args.audio_file is None and args.manifest_file is None) or ( args.audio_file is not None and args.manifest_file is not None @@ -293,6 +303,12 @@ def main(): else: raise ValueError("Decoder cannot get changed for non-Hybrid ASR models.") + if args.att_context_size is not None: + if hasattr(asr_model.encoder, "set_default_att_context_size"): + asr_model.encoder.set_default_att_context_size(att_context_size=json.loads(args.att_context_size)) + else: + raise ValueError("Model does not support multiple lookaheads.") + global autocast if ( args.use_amp diff --git a/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py b/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py new file mode 100644 index 000000000000..199e399ead11 --- /dev/null +++ b/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py @@ -0,0 +1,184 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +A script to convert a Nemo ASR Hybrid model file (.nemo) to a Nemo ASR CTC or RNNT model file (.nemo) + +This allows you to train an RNNT-CTC Hybrid model, but then convert it into a pure CTC or pure RNNT model for use +in NeMo. The resulting .nemo file will be a pure CTC or RNNT model, and can be used like any other .nemo model +including in nemo2riva. + +Usage: python convert_nemo_asr_hybrid_to_ctc.py -i /path/to/hybrid.nemo -o /path/to/saved_ctc_model.nemo -t ctc|rnnt + +""" + + +import argparse +import os +from copy import deepcopy + +import torch +from omegaconf import OmegaConf + +from nemo.collections.asr.models import ( + ASRModel, + EncDecCTCModel, + EncDecCTCModelBPE, + EncDecRNNTBPEModel, + EncDecRNNTModel, +) +from nemo.utils import logging + + +def extract_model_ctc(args, hybrid_model): + """ + A function which converts a hybrid model to a pure ctc model. 
+ Args: + args (argparse): the args collection from ArgumentParser created by running this script + hybrid_model (ASRModel): the loaded hybrid RNNT-CTC Nemo model + """ + BPE = False + ctc_class = EncDecCTCModel + if 'tokenizer' in hybrid_model.cfg.keys(): + BPE = True + ctc_class = EncDecCTCModelBPE + + hybrid_model_cfg = OmegaConf.to_container(hybrid_model.cfg) + + new_cfg = deepcopy(hybrid_model_cfg) + new_cfg['ctc_reduction'] = hybrid_model_cfg['aux_ctc']['ctc_reduction'] + new_cfg['decoder'] = hybrid_model_cfg['aux_ctc']['decoder'] + del new_cfg['compute_eval_loss'] + del new_cfg['model_defaults'] + del new_cfg['joint'] + del new_cfg['decoding'] + del new_cfg['aux_ctc'] + del new_cfg['loss'] + if BPE and 'labels' in new_cfg: + del new_cfg['labels'] + elif (not BPE) and 'tokenizer' in new_cfg: + del new_cfg['tokenizer'] + del new_cfg['target'] + del new_cfg['nemo_version'] + + new_cfg_oc = OmegaConf.create(new_cfg) + + # we call restore_from with strict=False because the .nemo file we're restoring from is a hybrid model, which will have named + # tensors in the state_dict that do not exist in the pure CTC model class, which would result in an exception with strict=True + ctc_model = ctc_class.restore_from( + args.input, map_location=torch.device('cpu'), override_config_path=new_cfg_oc, strict=False + ) + + assert all( + [ + torch.allclose(hybrid_model.state_dict()[x], ctc_model.state_dict()[x]) + for x in hybrid_model.state_dict().keys() + if x.split('.')[0] in ['preprocessor', 'encoder'] + ] + ), "Encoder and preprocessor state dicts don't match!" + + ctc_model.decoder.load_state_dict(hybrid_model.ctc_decoder.state_dict()) + + assert all( + [ + torch.allclose(hybrid_model.ctc_decoder.state_dict()[x], ctc_model.decoder.state_dict()[x]) + for x in hybrid_model.ctc_decoder.state_dict().keys() + ] + ), "Decoder state_dict load failed!" + + assert isinstance(ctc_model, ctc_class), "Extracted CTC model is of the wrong expected class!" + + return ctc_model + + +def extract_model_rnnt(args, hybrid_model): + """ + A function which converts a hybrid model to a pure rnnt model. + Args: + args (argparse): the args collection from ArgumentParser created by running this script + hybrid_model (ASRModel): the loaded hybrid RNNT-CTC Nemo model + """ + BPE = False + rnnt_class = EncDecRNNTModel + if 'tokenizer' in hybrid_model.cfg.keys(): + BPE = True + rnnt_class = EncDecRNNTBPEModel + + hybrid_model_cfg = OmegaConf.to_container(hybrid_model.cfg) + + new_cfg = deepcopy(hybrid_model_cfg) + del new_cfg['aux_ctc'] + if BPE and 'labels' in new_cfg: + del new_cfg['labels'] + elif (not BPE) and 'tokenizer' in new_cfg: + del new_cfg['tokenizer'] + del new_cfg['target'] + del new_cfg['nemo_version'] + + new_cfg_oc = OmegaConf.create(new_cfg) + + # we call restore_from with strict=False because the .nemo file we're restoring from is a hybrid model, which will have named + # tensors in the state_dict that do not exist in the pure RNNT model class, which would result in an exception with strict=True + rnnt_model = rnnt_class.restore_from( + args.input, map_location=torch.device('cpu'), override_config_path=new_cfg_oc, strict=False + ) + + assert all( + [ + torch.allclose(hybrid_model.state_dict()[x], rnnt_model.state_dict()[x]) + for x in hybrid_model.state_dict().keys() + if x.split('.')[0] in ['preprocessor', 'encoder', 'decoder', 'joint'] + ] + ), "State dict values mismatch, something went wrong!" + + assert isinstance(rnnt_model, rnnt_class), "Extracted RNNT model is of the wrong expected class!" 
+ + return rnnt_model + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--input', required=True, type=str, help='path to Nemo Hybrid model .nemo file') + parser.add_argument('-o', '--output', required=True, type=str, help='path and name of output .nemo file') + parser.add_argument( + '-t', + '--model_type', + required=False, + type=str, + default='ctc', + choices=['ctc', 'rnnt'], + help='whether to output a ctc or rnnt model from the hybrid', + ) + + args = parser.parse_args() + + if not os.path.exists(args.input): + logging.critical(f'Input file [ {args.input} ] does not exist or cannot be found. Aborting.') + exit(255) + + hybrid_model = ASRModel.restore_from(args.input, map_location=torch.device('cpu')) + + if args.model_type == 'ctc': + output_model = extract_model_ctc(args, hybrid_model) + elif args.model_type == 'rnnt': + output_model = extract_model_rnnt(args, hybrid_model) + else: + logging.critical( + f"the model_type arg must be one of 'ctc' or 'rnnt', received unknown value: '{args.model_type}'. Aborting." + ) + exit(255) + + output_model.save_to(args.output) + logging.info(f'Converted {args.model_type.upper()} model was successfully saved to {args.output}') diff --git a/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml b/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml new file mode 100644 index 000000000000..4e480df62e59 --- /dev/null +++ b/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml @@ -0,0 +1,218 @@ +# It contains the default values for training an autoregressive FastConformer-Transformer ST model with sub-word encoding. + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of FastConformer-Transformer, other parameters are the same as in this config file. +# One extra (linear projection) layer is added between FastConformer encoder and Transformer decoder if they have different hidden sizes +# It is recommended to initialize FastConformer with ASR pre-trained encoder for better accuracy and faster convergence + +name: "FastConformer-Transformer-BPE-st" + +# Initialize model encoder with pre-trained ASR FastConformer encoder for faster convergence and improved accuracy +init_from_nemo_model: + model0: + path: ??? + include: ["preprocessor", "encoder"] + +model: + sample_rate: 16000 + label_smoothing: 0.0 + log_prediction: true # enables logging sample predictions in the output during training + + train_ds: + is_tarred: true + tarred_audio_filepaths: ??? + manifest_filepath: ??? + sample_rate: 16000 + shuffle: false + trim_silence: false + batch_size: 4 + num_workers: 8 + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 4 + pin_memory: true + use_start_end_token: true + + test_ds: + manifest_filepath: ??? 
+ sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 4 + pin_memory: true + use_start_end_token: true + + # recommend small vocab size of 128 or 256 when using 4x sub-sampling + # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + pad_value: 0.0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + # you may use lower time_masks for smaller models to have a faster convergence + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + + # Sub-sampling params + subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory + subsampling_factor: 8 # must be power of 2 + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + reduction: null + reduction_position: null + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: batch_norm + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + transf_encoder: + num_layers: 0 + hidden_size: 512 + inner_size: 2048 + num_attention_heads: 8 + ffn_dropout: 0.1 + attn_score_dropout: 0.1 + attn_layer_dropout: 0.1 + + transf_decoder: + library: nemo + model_name: null + pretrained: false + max_sequence_length: 512 + num_token_types: 0 + embedding_dropout: 0.1 + learn_positional_encodings: false + hidden_size: 512 + inner_size: 2048 + num_layers: 6 + num_attention_heads: 4 + ffn_dropout: 0.1 + attn_score_dropout: 0.1 + attn_layer_dropout: 0.1 + hidden_act: relu + pre_ln: true + pre_ln_final_layer_norm: true + + head: + num_layers: 1 + activation: relu + log_softmax: true + dropout: 0.0 + use_transformer_init: true + + beam_search: + beam_size: 4 + len_pen: 0.0 + max_generation_delta: 50 + + optim: + name: adam + lr: 0.0001 + # optimizer arguments + betas: [0.9, 0.98] + # less necessity for weight_decay as we already have large augmentations with 
SpecAug + # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used + # weight decay of 0.0 with lr of 2.0 also works fine + #weight_decay: 1e-3 + + # scheduler setup + sched: + name: InverseSquareRootAnnealing + #d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 1000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + gpus: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 100 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 100 # Interval of logging. + enable_progress_bar: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_sacreBLEU" + mode: "max" + save_top_k: 3 + always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints + + # you need to set these two to True to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null \ No newline at end of file diff --git a/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml b/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml index d759a809ec37..30c082aff91f 100644 --- a/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml +++ b/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml @@ -1,6 +1,7 @@ name: &name "vad_inference_postprocessing" -dataset: null # Path of json file of evaluation data. Audio files should have unique names +input_manifest: null # Path of json file of evaluation data. 
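The ST config header notes that the default optimizer settings assume an effective batch size of roughly 2K, while `train_ds.batch_size` is only 4, so the remainder is usually made up with more devices and `trainer.accumulate_grad_batches`. A back-of-the-envelope sketch, assuming a hypothetical cluster size:

```python
# Effective batch = micro batch * GPUs * nodes * gradient accumulation.
micro_batch_size = 4           # model.train_ds.batch_size in the config above
num_gpus, num_nodes = 8, 4     # hypothetical cluster size
target_effective_batch = 2048  # ~2K, as recommended in the config header

accumulate_grad_batches = max(
    1, round(target_effective_batch / (micro_batch_size * num_gpus * num_nodes))
)
print(accumulate_grad_batches)  # 2048 / (4 * 8 * 4) = 16
```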
Audio files should have unique names +output_dir: null # Path to output directory where results will be stored num_workers: 12 sample_rate: 16000 evaluate: False # whether to get AUROC and DERs, the manifest must contains groundtruth if enabled diff --git a/examples/asr/speech_classification/frame_vad_infer.py b/examples/asr/speech_classification/frame_vad_infer.py index f716eb45bb64..594cc9637d73 100644 --- a/examples/asr/speech_classification/frame_vad_infer.py +++ b/examples/asr/speech_classification/frame_vad_infer.py @@ -21,7 +21,8 @@ ## Usage: python frame_vad_infer.py \ --config-path="../conf/vad" --config-name="frame_vad_infer_postprocess" \ - dataset= + input_manifest= \ + output_dir= The manifest json file should have the following format (each line is a Python dictionary): {"audio_filepath": "/path/to/audio_file1", "offset": 0, "duration": 10000} @@ -58,15 +59,25 @@ @hydra_runner(config_path="../conf/vad", config_name="frame_vad_infer_postprocess") def main(cfg): - if not cfg.dataset: + if not cfg.input_manifest: raise ValueError("You must input the path of json file of evaluation data") + output_dir = cfg.output_dir if cfg.output_dir else "frame_vad_outputs" + if os.path.exists(output_dir): + logging.warning( + f"Output directory {output_dir} already exists, use this only if you're tuning post-processing params." + ) + Path(output_dir).mkdir(parents=True, exist_ok=True) + + cfg.frame_out_dir = os.path.join(output_dir, "frame_preds") + cfg.smoothing_out_dir = os.path.join(output_dir, "smoothing_preds") + cfg.rttm_out_dir = os.path.join(output_dir, "rttm_preds") - # each line of dataset should be have different audio_filepath and unique name to simplify edge cases or conditions - logging.info(f"Loading manifest file {cfg.dataset}") + # each line of input_manifest should be have different audio_filepath and unique name to simplify edge cases or conditions + logging.info(f"Loading manifest file {cfg.input_manifest}") manifest_orig, key_labels_map, key_rttm_map = frame_vad_infer_load_manifest(cfg) # Prepare manifest for streaming VAD - manifest_vad_input = cfg.dataset + manifest_vad_input = cfg.input_manifest if cfg.prepare_manifest.auto_split: logging.info("Split long audio file to avoid CUDA memory issue") logging.debug("Try smaller split_duration if you still have CUDA memory issue") @@ -76,6 +87,7 @@ def main(cfg): 'split_duration': cfg.prepare_manifest.split_duration, 'num_workers': cfg.num_workers, 'prepared_manifest_vad_input': cfg.prepared_manifest_vad_input, + 'out_dir': output_dir, } manifest_vad_input = prepare_manifest(config) else: @@ -171,7 +183,7 @@ def main(cfg): key_pred_rttm_map[key] = entry['rttm_filepath'] if not cfg.out_manifest_filepath: - out_manifest_filepath = "manifest_vad_output.json" + out_manifest_filepath = os.path.join(output_dir, "manifest_vad_output.json") else: out_manifest_filepath = cfg.out_manifest_filepath write_manifest(out_manifest_filepath, manifest_new) diff --git a/examples/asr/speech_to_text_eval.py b/examples/asr/speech_to_text_eval.py index f4d2a66ffec0..452aa8202660 100644 --- a/examples/asr/speech_to_text_eval.py +++ b/examples/asr/speech_to_text_eval.py @@ -76,6 +76,11 @@ class EvaluationConfig(transcribe_speech.TranscriptionConfig): dataset_manifest: str = MISSING output_filename: Optional[str] = "evaluation_transcripts.json" + # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Joint RNNT/CTC models + decoder_type: Optional[str] = None + # att_context_size can be set for cache-aware streaming models 
with multiple look-aheads + att_context_size: Optional[list] = None + use_cer: bool = False tolerance: Optional[float] = None diff --git a/examples/asr/speech_translation/speech_to_text_transformer.py b/examples/asr/speech_translation/speech_to_text_transformer.py new file mode 100644 index 000000000000..0c0882859b88 --- /dev/null +++ b/examples/asr/speech_translation/speech_to_text_transformer.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Training the model +```sh +python speech_to_text_transformer.py \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.audio.tarred_audio_filepaths= \ + model.train_ds.audio_manifest_filepath= \ + model.validation_ds.manifest_filepath= \ + model.test_ds.manifest_filepath= \ + model.tokenizer.dir= \ + model.tokenizer.model_path= \ + model.tokenizer.type= \ + trainer.gpus=-1 \ + trainer.accelerator="ddp" \ + trainer.max_epochs=100 \ + model.optim.name="adamw" \ + model.optim.lr=0.001 \ + model.optim.betas=[0.9,0.999] \ + model.optim.weight_decay=0.0001 \ + model.optim.sched.warmup_steps=2000 + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name="" \ + exp_manager.wandb_logger_kwargs.project="" +``` + + +""" + +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecTransfModelBPE +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="../conf/speech_translation/", config_name="fast-conformer_transformer") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + asr_model = EncDecTransfModelBPE(cfg=cfg.model, trainer=trainer) + + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(asr_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if asr_model.prepare_test(trainer): + trainer.test(asr_model) + + +if __name__ == '__main__': + main() diff --git a/examples/asr/speech_translation/translate_speech.py b/examples/asr/speech_translation/translate_speech.py new file mode 100644 index 000000000000..203852b52ee9 --- /dev/null +++ b/examples/asr/speech_translation/translate_speech.py @@ -0,0 +1,210 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import json +import os +from dataclasses import dataclass, is_dataclass +from typing import List, Optional, Union + +import pytorch_lightning as pl +import torch +from omegaconf import OmegaConf + +from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig +from nemo.collections.asr.parts.utils.transcribe_utils import compute_output_filename, prepare_audio_data, setup_model +from nemo.core.config import hydra_runner +from nemo.utils import logging + +""" +Translate audio file on a single CPU/GPU. Useful for translations of moderate amounts of audio data. + +# Arguments + model_path: path to .nemo ST checkpoint + pretrained_name: name of pretrained ST model (from NGC registry) + audio_dir: path to directory with audio files + dataset_manifest: path to dataset JSON manifest file (in NeMo format) + + output_filename: Output filename where the translations will be written + batch_size: batch size during inference + + cuda: Optional int to enable or disable execution of model on certain CUDA device. + allow_mps: Bool to allow using MPS (Apple Silicon M-series GPU) device if available + amp: Bool to decide if Automatic Mixed Precision should be used during inference + audio_type: Str filetype of the audio. Supported = wav, flac, mp3 + + overwrite_translations: Bool which when set allows repeated translations to overwrite previous results. + +# Usage +ST model can be specified by either "model_path" or "pretrained_name". +Data for translation can be defined with either "audio_dir" or "dataset_manifest". +Results are returned in a JSON manifest file. + +python translate_speech.py \ + model_path=null \ + pretrained_name=null \ + audio_dir="" \ + dataset_manifest="" \ + output_filename="" \ + batch_size=32 \ + cuda=0 \ + amp=True \ +""" + + +@dataclass +class ModelChangeConfig: + + # Sub-config for changes specific to the Conformer Encoder + conformer: ConformerChangeConfig = ConformerChangeConfig() + + +@dataclass +class TranslationConfig: + # Required configs + model_path: Optional[str] = None # Path to a .nemo file + pretrained_name: Optional[str] = None # Name of a pretrained model + audio_dir: Optional[str] = None # Path to a directory which contains audio files + dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest + audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest + eval_config_yaml: Optional[str] = None # Path to a yaml file of config of evaluation + + # General configs + output_filename: Optional[str] = None + batch_size: int = 32 + random_seed: Optional[int] = None # seed number going to be used in seed_everything() + + # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA + # device anyway, and do inference on CPU only if CUDA device is not found. + # If `cuda` is a negative number, inference will be on CPU only. + cuda: Optional[int] = None + allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) + amp: bool = False + audio_type: str = "wav" + + # Recompute model translation, even if the output folder exists with scores. 
+ overwrite_translations: bool = True + + # can be set to True to return list of translations instead of the config + # if True, will also skip writing anything to the output file + return_translations: bool = False + + +@hydra_runner(config_name="TranslationConfig", schema=TranslationConfig) +def main(cfg: TranslationConfig) -> Union[TranslationConfig, List[str]]: + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + for key in cfg: + cfg[key] = None if cfg[key] == 'None' else cfg[key] + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + if cfg.random_seed: + pl.seed_everything(cfg.random_seed) + + if cfg.model_path is None and cfg.pretrained_name is None: + raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!") + if cfg.audio_dir is None and cfg.dataset_manifest is None: + raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!") + + # Load augmentor from exteranl yaml file which contains eval info, could be extend to other feature such VAD, P&C + augmentor = None + if cfg.eval_config_yaml: + eval_config = OmegaConf.load(cfg.eval_config_yaml) + augmentor = eval_config.test_ds.get("augmentor") + logging.info(f"Will apply on-the-fly augmentation on samples during translation: {augmentor} ") + + # setup GPU + if cfg.cuda is None: + if torch.cuda.is_available(): + device = [0] # use 0th CUDA device + accelerator = 'gpu' + map_location = torch.device('cuda:0') + elif cfg.allow_mps and hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + logging.warning( + "MPS device (Apple Silicon M-series GPU) support is experimental." + " Env variable `PYTORCH_ENABLE_MPS_FALLBACK=1` should be set in most cases to avoid failures." + ) + device = [0] + accelerator = 'mps' + map_location = torch.device('mps') + else: + device = 1 + accelerator = 'cpu' + map_location = torch.device('cpu') + else: + device = [cfg.cuda] + accelerator = 'gpu' + map_location = torch.device(f'cuda:{cfg.cuda}') + + logging.info(f"Inference will be done on device: {map_location}") + + asr_model, model_name = setup_model(cfg, map_location) + trainer = pl.Trainer(devices=device, accelerator=accelerator) + asr_model.set_trainer(trainer) + asr_model = asr_model.eval() + + # collect additional translation information + return_hypotheses = False + + # prepare audio filepaths and decide wether it's partial audio + filepaths, partial_audio = prepare_audio_data(cfg) + + # setup AMP (optional) + if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): + logging.info("AMP enabled!\n") + autocast = torch.cuda.amp.autocast + else: + + @contextlib.contextmanager + def autocast(): + yield + + # Compute output filename + cfg = compute_output_filename(cfg, model_name) + + # if translations should not be overwritten, and already exists, skip re-translation step and return + if not cfg.return_translations and not cfg.overwrite_translations and os.path.exists(cfg.output_filename): + logging.info( + f"Previous translations found at {cfg.output_filename}, and flag `overwrite_translations`" + f"is {cfg.overwrite_translations}. Returning without re-translating text." 
+ ) + return cfg + + # translate audio + with autocast(): + with torch.no_grad(): + translations = asr_model.translate( + paths2audio_files=filepaths, batch_size=cfg.batch_size, return_hypotheses=return_hypotheses, + ) + + logging.info(f"Finished translating {len(filepaths)} files !") + logging.info(f"Writing translations into file: {cfg.output_filename}") + + if cfg.return_translations: + return translations + + # write audio translations + with open(cfg.output_filename, 'w', encoding='utf-8', newline='\n') as f: + for filepath, translation in zip(filepaths, translations): + item = {'audio_filepath': filepath, 'pred_translation': translation} + f.write(json.dumps(item, ensure_ascii=False) + "\n") + logging.info(f"Finished writing predictions to {cfg.output_filename}!") + + return cfg + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index f97dd96ad0f3..8e1be5ca1311 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -153,8 +153,10 @@ class TranscriptionConfig: # Decoding strategy for RNNT models rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1) - # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Joint RNNT/CTC models + # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models decoder_type: Optional[str] = None + # att_context_size can be set for cache-aware streaming models with multiple look-aheads + att_context_size: Optional[list] = None # Use this for model-specific changes before transcription model_change: ModelChangeConfig = ModelChangeConfig() @@ -246,6 +248,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis if cfg.decoder_type and cfg.decoder_type != 'rnnt': raise ValueError('RNNT model only support rnnt decoding!') + if cfg.decoder_type and hasattr(asr_model.encoder, 'set_default_att_context_size'): + asr_model.encoder.set_default_att_context_size(cfg.att_context_size) + # Setup decoding strategy if hasattr(asr_model, 'change_decoding_strategy'): if cfg.decoder_type is not None: diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py index f14df284c6b1..a57922f20d29 100644 --- a/examples/asr/transcribe_speech_parallel.py +++ b/examples/asr/transcribe_speech_parallel.py @@ -102,8 +102,10 @@ class ParallelTranscriptionConfig: # decoding strategy for RNNT models rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig() - # decoder for hybrid models, must be one of 'ctc', 'rnnt' if not None + # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models decoder_type: Optional[str] = None + # att_context_size can be set for cache-aware streaming models with multiple look-aheads + att_context_size: Optional[list] = None trainer: TrainerConfig = TrainerConfig(devices=-1, accelerator="gpu", strategy="ddp") diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml index 53d4e9b7e82b..b5b053fc1549 100644 --- a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml @@ -9,7 +9,7 @@ inference: repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. 
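`translate_speech.py` writes its results as a JSON-lines manifest, one object per audio file with `audio_filepath` and `pred_translation` keys (see the writer loop above). A small sketch for consuming that manifest, with a hypothetical output path:

```python
import json

# Hypothetical path given as output_filename=... when running translate_speech.py.
with open("translations.json", encoding="utf-8") as f:
    for line in f:
        item = json.loads(line)
        print(item["audio_filepath"], "->", item["pred_translation"])
```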
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False - + end_strings: ["<|endoftext|>"] # generation will stop when one of these tokens is generated trainer: devices: 1 diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index 2a6890e1a9b4..76e68d24bae8 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -267,6 +267,7 @@ def main(cfg) -> None: "add_BOS": cfg.inference.add_BOS, "all_probs": cfg.inference.all_probs, "compute_logprob": cfg.inference.compute_logprob, + "end_strings": cfg.inference.end_strings, } fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True) diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml index 799d105aae7c..d26dd2922088 100755 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml @@ -116,6 +116,7 @@ model: micro_batch_size: ${model.micro_batch_size} shuffle: True num_workers: 0 + memmap_workers: null pin_memory: True max_seq_length: 2048 min_seq_length: 1 @@ -143,6 +144,7 @@ model: micro_batch_size: ${model.micro_batch_size} shuffle: False num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} pin_memory: True max_seq_length: 2048 min_seq_length: 1 @@ -170,6 +172,7 @@ model: micro_batch_size: ${model.micro_batch_size} shuffle: False num_workers: 4 + memmap_workers: ${model.data.train_ds.memmap_workers} pin_memory: True max_seq_length: 2048 min_seq_length: 1 diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml index 0e3f0d712dd6..f15138c99264 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml @@ -82,6 +82,7 @@ model: micro_batch_size: ${model.micro_batch_size} shuffle: True num_workers: 4 + memmap_workers: null pin_memory: True max_seq_length: 2048 min_seq_length: 1 @@ -109,6 +110,7 @@ model: micro_batch_size: ${model.micro_batch_size} shuffle: True num_workers: 4 + memmap_workers: ${model.data.train_ds.memmap_workers} pin_memory: True max_seq_length: 2048 min_seq_length: 1 @@ -137,6 +139,7 @@ model: micro_batch_size: ${model.micro_batch_size} shuffle: True num_workers: 4 + memmap_workers: ${model.data.train_ds.memmap_workers} pin_memory: True max_seq_length: 2048 min_seq_length: 1 diff --git a/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py b/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py index 07d64ec5b723..68c55ff51a4f 100644 --- a/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py +++ b/examples/nlp/spellchecking_asr_customization/create_custom_vocab_index.py @@ -53,7 +53,7 @@ print("Size of customization vocabulary:", len(custom_phrases)) # Load n-gram mappings vocabulary -ngram_mapping_vocab, ban_ngram = load_ngram_mappings(args.ngram_mappings, max_misspelled_freq=125000) +ngram_mapping_vocab, ban_ngram = load_ngram_mappings(args.ngram_mappings, max_misspelled_freq=args.max_misspelled_freq) # Generate index of custom phrases phrases, 
ngram2phrases = get_index( diff --git a/examples/nlp/spellchecking_asr_customization/run_infer.sh b/examples/nlp/spellchecking_asr_customization/run_infer.sh index 09da98171c16..b4bbdc4da375 100644 --- a/examples/nlp/spellchecking_asr_customization/run_infer.sh +++ b/examples/nlp/spellchecking_asr_customization/run_infer.sh @@ -31,7 +31,7 @@ BIG_SAMPLE=spellmapper_asr_customization_en/big_sample.txt ## File with input nemo ASR manifest INPUT_MANIFEST=spellmapper_en_evaluation/medical_manifest_ctc.json ## File containing custom words and phrases (plain text) -CUSTOM_VOCAB=spellmapper_en_evaluation/medical_custom_vocab.json +CUSTOM_VOCAB=spellmapper_en_evaluation/medical_custom_vocab.txt ## Other files will be created ## File with index of custom vocabulary diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index a884f7d3cc68..894be6319c99 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -99,7 +99,7 @@ class RNNTLossConfig: min_version='0.53.0', is_available=NUMBA_RNNT_AVAILABLE, installation_msg=NUMBA_INSTALLATION_MESSAGE, - force_float32=not numba_utils.NUMBA_FP16_SUPPORTED, + force_float32=False, # This is only temporarily false, will be dynamically updated during resolution ), "pytorch": RNNTLossConfig( loss_name="pytorch", @@ -258,6 +258,9 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) _warn_unused_additional_kwargs(loss_name, loss_kwargs) elif loss_name == 'warprnnt_numba': + # Update loss config's forced float32 flag if set to None + loss_config.force_float32 = not numba_utils.is_numba_cuda_fp16_supported() + fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0) clamp = loss_kwargs.pop('clamp', -1.0) loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda, clamp=clamp) @@ -444,7 +447,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths): max_targets_len = target_lengths.max() # Force cast joint to float32 - if not self._force_float32 and numba_utils.NUMBA_FP16_SUPPORTED: + if not self._force_float32 and numba_utils.is_numba_cuda_fp16_supported(): # Execute the kernel in fp16 pass elif self._force_float32 and log_probs.dtype != torch.float32: diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 7e5636191a1d..87a48e50d58a 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -100,32 +100,33 @@ class AbstractRNNTDecoding(ConfidenceMixin): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - method_cfg: A dict-like object which contains the method name and settings to compute per-frame + measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. entropy_type: Which type of entropy to use (str). - Used if confidence_method_cfg.name is set to `entropy`. + Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. 
If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. @@ -139,7 +140,7 @@ class AbstractRNNTDecoding(ConfidenceMixin): timestep during greedy decoding. Setting to larger values allows longer sentences to be decoded, at the cost of increased execution time. preserve_frame_confidence: Same as above, overrides above value. - confidence_method: Same as above, overrides confidence_cfg.method. + confidence_measure_cfg: Same as above, overrides confidence_cfg.measure_cfg. "beam": beam_size: int, defining the beam size for beam search. Must be >= 1. 
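To make the renamed confidence settings concrete, the following toy sketch computes a linearly normalized entropy confidence from a per-frame probability vector, following the formulas in the docstring above; it is not NeMo's `get_confidence_measure_bank` implementation, and the function names and example values are illustrative:

```python
import math

def shannon_confidence(probs):
    """Linearly normalized Shannon-entropy confidence: 1 - H / log(V) (the alpha == 1 case)."""
    v = len(probs)
    h = -sum(p * math.log(p) for p in probs if p > 0.0)
    return 1.0 - h / math.log(v)

def renyi_confidence(probs, alpha=2.0):
    """Linearly normalized Renyi-entropy confidence: 1 - H_alpha / log2(V), for alpha > 0, alpha != 1."""
    v = len(probs)
    h = (1.0 / (1.0 - alpha)) * math.log2(sum(p ** alpha for p in probs))
    return 1.0 - h / math.log2(v)

peaked = [0.97, 0.01, 0.01, 0.01]   # confident frame
flat = [0.25, 0.25, 0.25, 0.25]     # uncertain frame

print(shannon_confidence(peaked))   # high (close to 1)
print(shannon_confidence(flat))     # 0.0: maximum entropy
print(renyi_confidence(peaked))     # high (close to 1)
print(renyi_confidence(flat))       # 0.0
```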
@@ -255,15 +256,13 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): # initialize confidence-related fields self._init_confidence(self.cfg.get('confidence_cfg', None)) - # Update preserve frame confidence - if self.preserve_frame_confidence is False: - if self.cfg.strategy in ['greedy', 'greedy_batch']: - self.preserve_frame_confidence = self.cfg.greedy.get('preserve_frame_confidence', False) - self.confidence_method_cfg = self.cfg.greedy.get('confidence_method_cfg', None) - - elif self.cfg.strategy in ['beam', 'tsd', 'alsd', 'maes']: - # Not implemented - pass + # Confidence estimation is not implemented for these strategies + if ( + not self.preserve_frame_confidence + and self.cfg.strategy in ['beam', 'tsd', 'alsd', 'maes'] + and self.cfg.beam.get('preserve_frame_confidence', False) + ): + raise NotImplementedError(f"Confidence calculation is not supported for strategy `{self.cfg.strategy}`") if self.cfg.strategy == 'greedy': if self.big_blank_durations is None: @@ -278,7 +277,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_method_cfg=self.confidence_method_cfg, + confidence_measure_cfg=self.confidence_measure_cfg, ) else: self.decoding = greedy_decode.GreedyTDTInfer( @@ -292,7 +291,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_method_cfg=self.confidence_method_cfg, + confidence_measure_cfg=self.confidence_measure_cfg, ) else: self.decoding = greedy_decode.GreedyMultiblankRNNTInfer( @@ -305,7 +304,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_method_cfg=self.confidence_method_cfg, + confidence_measure_cfg=self.confidence_measure_cfg, ) elif self.cfg.strategy == 'greedy_batch': @@ -321,7 +320,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_method_cfg=self.confidence_method_cfg, + confidence_measure_cfg=self.confidence_measure_cfg, ) else: self.decoding = greedy_decode.GreedyBatchedTDTInfer( @@ -335,7 +334,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_method_cfg=self.confidence_method_cfg, + confidence_measure_cfg=self.confidence_measure_cfg, ) else: @@ -349,7 +348,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_method_cfg=self.confidence_method_cfg, + confidence_measure_cfg=self.confidence_measure_cfg, ) elif self.cfg.strategy == 'beam': @@ -1006,32 +1005,33 @@ class RNNTDecoding(AbstractRNNTDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - method_cfg: A dict-like object which contains the method name and settings to compute per-frame + measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). 
+ name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. entropy_type: Which type of entropy to use (str). - Used if confidence_method_cfg.name is set to `entropy`. + Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. @@ -1047,7 +1047,7 @@ class RNNTDecoding(AbstractRNNTDecoding): preserve_frame_confidence: Same as above, overrides above value. - confidence_method: Same as above, overrides confidence_cfg.method. + confidence_measure_cfg: Same as above, overrides confidence_cfg.measure_cfg. "beam": beam_size: int, defining the beam size for beam search. Must be >= 1. diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index d2e2c3cc5923..3fb50d2a1ee2 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -100,32 +100,33 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - method_cfg: A dict-like object which contains the method name and settings to compute per-frame + measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. entropy_type: Which type of entropy to use (str). - Used if confidence_method_cfg.name is set to `entropy`. + Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. 
If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. @@ -141,7 +142,7 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): preserve_frame_confidence: Same as above, overrides above value. - confidence_method: Same as above, overrides confidence_cfg.method. + confidence_measure_cfg: Same as above, overrides confidence_cfg.measure_cfg. "beam": beam_size: int, defining the beam size for beam search. Must be >= 1. diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py index 4d90810cc3df..a88895763edc 100644 --- a/nemo/collections/asr/metrics/wer.py +++ b/nemo/collections/asr/metrics/wer.py @@ -35,14 +35,17 @@ def word_error_rate(hypotheses: List[str], references: List[str], use_cer=False) -> float: """ Computes Average Word Error rate between two texts represented as - corresponding lists of string. Hypotheses and references must have same - length. + corresponding lists of string. + + Hypotheses and references must have same length. + Args: - hypotheses: list of hypotheses - references: list of references - use_cer: bool, set True to enable cer + hypotheses (list): list of hypotheses + references(list) : list of references + use_cer (bool): set True to enable cer + Returns: - (float) average word error rate + wer (float): average word error rate """ scores = 0 words = 0 @@ -78,17 +81,18 @@ def word_error_rate_detail( between two texts represented as corresponding lists of string. Hypotheses and references must have same length. 
+ Args: - hypotheses (list): list of hypotheses - references(list) : list of references - use_cer (bool): set True to enable cer - Returns: - wer (float): average word error rate - words (int): Total number of words/charactors of given reference texts - ins_rate (float): average insertion error rate - del_rate (float): average deletion error rate - sub_rate (float): average substitution error rate + hypotheses (list): list of hypotheses + references(list) : list of references + use_cer (bool): set True to enable cer + Returns: + wer (float): average word error rate + words (int): Total number of words/charactors of given reference texts + ins_rate (float): average insertion error rate + del_rate (float): average deletion error rate + sub_rate (float): average substitution error rate """ scores = 0 words = 0 @@ -141,6 +145,68 @@ def word_error_rate_detail( return wer, words, ins_rate, del_rate, sub_rate +def word_error_rate_per_utt(hypotheses: List[str], references: List[str], use_cer=False) -> Tuple[List[float], float]: + """ + Computes Word Error Rate per utterance and the average WER + between two texts represented as corresponding lists of string. + + Hypotheses and references must have same length. + + Args: + hypotheses (list): list of hypotheses + references(list) : list of references + use_cer (bool): set True to enable cer + + Returns: + wer_per_utt (List[float]): word error rate per utterance + avg_wer (float): average word error rate + """ + scores = 0 + words = 0 + wer_per_utt = [] + + if len(hypotheses) != len(references): + raise ValueError( + "In word error rate calculation, hypotheses and reference" + " lists must have the same number of elements. But I got:" + "{0} and {1} correspondingly".format(len(hypotheses), len(references)) + ) + + for h, r in zip(hypotheses, references): + if use_cer: + h_list = list(h) + r_list = list(r) + else: + h_list = h.split() + r_list = r.split() + + # To get rid of the issue that jiwer does not allow empty string + if len(r_list) == 0: + if len(h_list) != 0: + errors = len(h_list) + wer_per_utt.append(float('inf')) + else: + if use_cer: + measures = jiwer.cer(r, h, return_dict=True) + er = measures['cer'] + else: + measures = jiwer.compute_measures(r, h) + er = measures['wer'] + + errors = measures['insertions'] + measures['deletions'] + measures['substitutions'] + wer_per_utt.append(er) + + scores += errors + words += len(r_list) + + if words != 0: + avg_wer = 1.0 * scores / words + else: + avg_wer = float('inf') + + return wer_per_utt, avg_wer + + def move_dimension_to_the_front(tensor, dim_index): all_dims = list(range(tensor.ndim)) return tensor.permute(*([dim_index] + all_dims[:dim_index] + all_dims[dim_index + 1 :])) @@ -192,32 +258,33 @@ class AbstractCTCDecoding(ConfidenceMixin): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - method_cfg: A dict-like object which contains the method name and settings to compute per-frame + measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. entropy_type: Which type of entropy to use (str). - Used if confidence_method_cfg.name is set to `entropy`. 
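The new `word_error_rate_per_utt` helper returns one error rate per utterance plus the corpus-level average (total errors divided by total reference words). A small usage sketch with made-up hypotheses and references:

```python
from nemo.collections.asr.metrics.wer import word_error_rate_per_utt

hypotheses = ["the cat sat", "hello word"]
references = ["the cat sat on the mat", "hello world"]

wer_per_utt, avg_wer = word_error_rate_per_utt(hypotheses, references)
# Utterance 1: 3 deletions / 6 reference words   -> 0.5
# Utterance 2: 1 substitution / 2 reference words -> 0.5
# Corpus average: (3 + 1) errors / (6 + 2) words  = 0.5
print(wer_per_utt, avg_wer)
```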
+ Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. @@ -233,6 +300,7 @@ class AbstractCTCDecoding(ConfidenceMixin): preserve_alignments: Same as above, overrides above value. compute_timestamps: Same as above, overrides above value. preserve_frame_confidence: Same as above, overrides above value. + confidence_measure_cfg: Same as above, overrides confidence_cfg.measure_cfg. "beam": beam_size: int, defining the beam size for beam search. Must be >= 1. @@ -302,6 +370,14 @@ def __init__(self, decoding_cfg, blank_id: int): # initialize confidence-related fields self._init_confidence(self.cfg.get('confidence_cfg', None)) + # Confidence estimation is not implemented for strategies other than `greedy` + if ( + not self.preserve_frame_confidence + and self.cfg.strategy != 'greedy' + and self.cfg.beam.get('preserve_frame_confidence', False) + ): + raise NotImplementedError(f"Confidence calculation is not supported for strategy `{self.cfg.strategy}`") + # we need timestamps to extract non-blank per-frame confidence if self.compute_timestamps is not None: self.compute_timestamps |= self.preserve_frame_confidence @@ -313,7 +389,7 @@ def __init__(self, decoding_cfg, blank_id: int): preserve_alignments=self.preserve_alignments, compute_timestamps=self.compute_timestamps, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_method_cfg=self.confidence_method_cfg, + confidence_measure_cfg=self.confidence_measure_cfg, ) elif self.cfg.strategy == 'beam': @@ -961,32 +1037,33 @@ class CTCDecoding(AbstractCTCDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - method_cfg: A dict-like object which contains the method name and settings to compute per-frame + measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. 
- name: The method name (str). + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. entropy_type: Which type of entropy to use (str). - Used if confidence_method_cfg.name is set to `entropy`. + Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. @@ -1002,6 +1079,7 @@ class CTCDecoding(AbstractCTCDecoding): preserve_alignments: Same as above, overrides above value. compute_timestamps: Same as above, overrides above value. preserve_frame_confidence: Same as above, overrides above value. + confidence_measure_cfg: Same as above, overrides confidence_cfg.measure_cfg. "beam": beam_size: int, defining the beam size for beam search. Must be >= 1. diff --git a/nemo/collections/asr/metrics/wer_bpe.py b/nemo/collections/asr/metrics/wer_bpe.py index 8a92e4745a1b..524294d61c50 100644 --- a/nemo/collections/asr/metrics/wer_bpe.py +++ b/nemo/collections/asr/metrics/wer_bpe.py @@ -74,32 +74,33 @@ class CTCBPEDecoding(AbstractCTCDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - method_cfg: A dict-like object which contains the method name and settings to compute per-frame + measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. entropy_type: Which type of entropy to use (str). - Used if confidence_method_cfg.name is set to `entropy`. + Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. 
If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. @@ -115,6 +116,7 @@ class CTCBPEDecoding(AbstractCTCDecoding): preserve_alignments: Same as above, overrides above value. compute_timestamps: Same as above, overrides above value. preserve_frame_confidence: Same as above, overrides above value. + confidence_measure_cfg: Same as above, overrides confidence_cfg.measure_cfg. "beam": beam_size: int, defining the beam size for beam search. Must be >= 1. 
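With the rename, confidence settings are built from `measure_cfg` and `alpha` rather than `method_cfg` and `temperature`. A minimal sketch of constructing such a config programmatically, using the dataclasses referenced elsewhere in this change set; the concrete values are illustrative:

```python
from nemo.collections.asr.parts.utils.asr_confidence_utils import (
    ConfidenceConfig,
    ConfidenceMeasureConfig,
)

# Illustrative values; the point is the renamed fields:
# measure_cfg (was method_cfg) and alpha (was temperature).
confidence_cfg = ConfidenceConfig(
    exclude_blank=True,
    aggregation="prod",
    measure_cfg=ConfidenceMeasureConfig(
        name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",
    ),
)
```

Such a config is then picked up through the decoding config's `confidence_cfg` entry, which the decoding classes read via `self._init_confidence(self.cfg.get('confidence_cfg', None))` as shown above.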
diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index a7275faea3d0..34f2c4f62e29 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -33,3 +33,4 @@ from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel +from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py index 432674225f5a..264e9cef99f8 100644 --- a/nemo/collections/asr/models/classification_models.py +++ b/nemo/collections/asr/models/classification_models.py @@ -35,6 +35,7 @@ from nemo.core.classes.common import PretrainedModelInfo, typecheck from nemo.core.neural_types import * from nemo.utils import logging, model_utils +from nemo.utils.cast_utils import cast_all __all__ = ['EncDecClassificationModel', 'EncDecRegressionModel'] @@ -851,6 +852,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.eval_loop_cnt = 0 self.ratio_threshold = cfg.get('ratio_threshold', 0.2) super().__init__(cfg=cfg, trainer=trainer) + self.decoder.output_types = self.output_types + self.decoder.output_types_for_export = self.output_types @classmethod def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]: @@ -1148,3 +1151,43 @@ def get_metric_logits_labels(self, logits, labels, masks): labels = labels.gather(dim=0, index=idx.view(-1)) return logits, labels + + def forward_for_export( + self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + ): + """ + This forward is used when we need to export the model to ONNX format. + Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models. + Args: + input: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps. + length: Vector of length B, that contains the individual lengths of the audio sequences. 
+ cache_last_channel: Tensor of shape [N, B, T, H] which contains the cache for last channel layers + cache_last_time: Tensor of shape [N, B, H, T] which contains the cache for last time layers + N is the number of such layers which need caching, B is batch size, H is the hidden size of activations, + and T is the length of the cache + + Returns: + the output of the model + """ + enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward) + if cache_last_channel is None: + encoder_output = enc_fun(audio_signal=input, length=length) + if isinstance(encoder_output, tuple): + encoder_output = encoder_output[0] + else: + encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun( + audio_signal=input, + length=length, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + ) + + dec_fun = getattr(self.output_module, 'forward_for_export', self.output_module.forward) + ret = dec_fun(hidden_states=encoder_output.transpose(1, 2)) + if isinstance(ret, tuple): + ret = ret[0] + if cache_last_channel is not None: + ret = (ret, length, cache_last_channel, cache_last_time, cache_last_channel_len) + return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32) diff --git a/nemo/collections/asr/models/confidence_ensemble.py b/nemo/collections/asr/models/confidence_ensemble.py index 9b3191c8874d..bf65ff96ef5c 100644 --- a/nemo/collections/asr/models/confidence_ensemble.py +++ b/nemo/collections/asr/models/confidence_ensemble.py @@ -25,7 +25,7 @@ from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel from nemo.collections.asr.parts.utils.asr_confidence_utils import ( ConfidenceConfig, - ConfidenceMethodConfig, + ConfidenceMeasureConfig, get_confidence_aggregation_bank, get_confidence_measure_bank, ) @@ -61,8 +61,8 @@ def to_confidence_config(self) -> ConfidenceConfig: return ConfidenceConfig( exclude_blank=self.exclude_blank, aggregation=self.aggregation, - method_cfg=ConfidenceMethodConfig( - name=name, entropy_type=entropy_type, temperature=self.alpha, entropy_norm=entropy_norm, + measure_cfg=ConfidenceMeasureConfig( + name=name, entropy_type=entropy_type, alpha=self.alpha, entropy_norm=entropy_norm, ), ) @@ -135,12 +135,12 @@ def compute_confidence(hypothesis: Hypothesis, confidence_cfg: ConfidenceConfig) filtered_logprobs = get_filtered_logprobs(hypothesis, confidence_cfg.exclude_blank) vocab_size = filtered_logprobs.shape[1] aggr_func = get_confidence_aggregation_bank()[confidence_cfg.aggregation] - if confidence_cfg.method_cfg.name == "max_prob": + if confidence_cfg.measure_cfg.name == "max_prob": conf_type = "max_prob" alpha = 1.0 else: - conf_type = f"entropy_{confidence_cfg.method_cfg.entropy_type}_{confidence_cfg.method_cfg.entropy_norm}" - alpha = confidence_cfg.method_cfg.temperature + conf_type = f"entropy_{confidence_cfg.measure_cfg.entropy_type}_{confidence_cfg.measure_cfg.entropy_norm}" + alpha = confidence_cfg.measure_cfg.alpha conf_func = get_confidence_measure_bank()[conf_type] conf_value = aggr_func(conf_func(filtered_logprobs, v=vocab_size, t=alpha)).cpu().item() diff --git a/nemo/collections/asr/models/hybrid_asr_tts_models.py b/nemo/collections/asr/models/hybrid_asr_tts_models.py index 8486f956c3b7..8494a093b29d 100644 --- a/nemo/collections/asr/models/hybrid_asr_tts_models.py +++ b/nemo/collections/asr/models/hybrid_asr_tts_models.py @@ -311,8 +311,10 @@ def from_pretrained_models( ) ) else: + cfg = 
copy.deepcopy(cfg) # copy to avoid modifying original config cfg.tts_model_path = f"{tts_model_path}" cfg.asr_model_path = f"{asr_model_path}" + cfg.enhancer_model_path = f"{enhancer_model_path}" if enhancer_model_path is not None else None return ASRWithTTSModel(cfg, trainer=trainer) def __setattr__(self, name, value): diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index 6604983b6461..7f1a22a9b2b8 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -544,4 +544,11 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: ) results.append(model) + model = PretrainedModelInfo( + pretrained_model_name="stt_en_fastconformer_hybrid_large_streaming_multi", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_multi", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_fastconformer_hybrid_large_streaming_multi/versions/1.20.0/files/stt_en_fastconformer_hybrid_large_streaming_multi.nemo", + ) + results.append(model) + return results diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py new file mode 100644 index 000000000000..178746795ae8 --- /dev/null +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -0,0 +1,614 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
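As a usage sketch for the multi-lookahead checkpoint registered in list_available_models() above, the model can be restored by name through NeMo's standard from_pretrained entry point. The audio path and the att_context_size pair below are placeholders, not values taken from this patch; the chosen look-ahead should be one of the values in encoder.att_context_size_all, otherwise the warning added to ConformerEncoder.set_default_att_context_size (later in this diff) is emitted.

import nemo.collections.asr as nemo_asr

# Restore the newly listed cache-aware streaming model with multiple look-aheads.
asr_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(
    model_name="stt_en_fastconformer_hybrid_large_streaming_multi"
)

# Select a look-ahead before inference; [70, 13] is illustrative only and must be
# replaced by one of the context sizes the checkpoint was actually trained with.
asr_model.encoder.set_default_att_context_size([70, 13])

# "sample.wav" is a placeholder audio file path.
transcripts = asr_model.transcribe(["sample.wav"], batch_size=1)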
+ +import itertools +import json +import os +import tempfile +from math import ceil +from typing import Dict, List, Optional, Union + +import editdistance +import torch +import torch.distributed as dist +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs +from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel +from nemo.collections.asr.parts.mixins import ASRBPEMixin +from nemo.collections.common.losses import SmoothedCrossEntropyLoss +from nemo.collections.common.metrics import GlobalAverageLossMetric +from nemo.collections.common.parts import transformer_weights_init + +from nemo.core.classes.common import typecheck +from nemo.core.neural_types import ( + AudioSignal, + ChannelType, + LabelsType, + LengthsType, + LogprobsType, + MaskType, + NeuralType, + SpectrogramType, +) +from nemo.utils import logging + +try: + from sacrebleu import corpus_bleu + from nemo.collections.nlp.modules.common import TokenClassifier + from nemo.collections.nlp.modules.common.lm_utils import get_transformer + from nemo.collections.nlp.modules.common.transformer import BeamSearchSequenceGenerator, TransformerEncoder + + NLP_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + NLP_AVAILABLE = False + logging.warning("Could not import NeMo NLP collection which is required for speech translation model.") + +__all__ = ['EncDecTransfModelBPE'] + + +def lens_to_mask(lens, max_length): + batch_size = lens.shape[0] + mask = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None] + return mask + + +class EncDecTransfModelBPE(ASRModel, ExportableEncDecModel, ASRBPEMixin): + """Base class for encoder decoder CTC-based models.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + + if 'tokenizer' not in cfg: + raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") + + # Setup the tokenizer + self._setup_tokenizer(cfg.tokenizer) + + super().__init__(cfg=cfg, trainer=trainer) + + # Setup audio preprocessor + self.preprocessor = EncDecTransfModelBPE.from_config_dict(self.cfg.preprocessor) + + # Setup audio encoder + self.encoder = EncDecTransfModelBPE.from_config_dict(self.cfg.encoder) + + # Add projection layer if encoder and decoder differ in hidden size + if self.cfg.encoder['d_model'] != self.cfg.transf_decoder['hidden_size']: + self.adapter = torch.nn.Linear(self.cfg.encoder['d_model'], self.cfg.transf_decoder['hidden_size']) + else: + self.adapter = torch.nn.Identity() + + transf_encoder_cfg_dict = OmegaConf.to_container(cfg.get('transf_encoder')) + + # Whether to add Transformer Encoder block between Conformer and Transformer Decoder + self.use_transf_encoder = False + if transf_encoder_cfg_dict['num_layers'] > 0: + self.use_transf_encoder = True + + self.transf_encoder = TransformerEncoder( + num_layers=transf_encoder_cfg_dict['num_layers'], + hidden_size=transf_encoder_cfg_dict['hidden_size'], + inner_size=transf_encoder_cfg_dict['inner_size'], + mask_future=False, + num_attention_heads=transf_encoder_cfg_dict['num_attention_heads'], + attn_score_dropout=transf_encoder_cfg_dict['attn_score_dropout'], + attn_layer_dropout=transf_encoder_cfg_dict['attn_layer_dropout'], + ffn_dropout=transf_encoder_cfg_dict['ffn_dropout'], + pre_ln=transf_encoder_cfg_dict.get('pre_ln', True), + 
pre_ln_final_layer_norm=transf_encoder_cfg_dict.get('pre_ln_final_layer_norm', True), + ) + std_init_range = 1 / transf_encoder_cfg_dict['hidden_size'] ** 0.5 + self.transf_encoder.apply(lambda module: transformer_weights_init(module, std_init_range)) + + transf_decoder_cfg_dict = OmegaConf.to_container(cfg.get('transf_decoder')) + + # Transformer decoder + vocab_size = 8 * ceil(self.tokenizer.vocab_size / 8) + transf_decoder_cfg_dict['vocab_size'] = vocab_size + library = transf_decoder_cfg_dict.pop('library', 'nemo') + model_name = transf_decoder_cfg_dict.pop('model_name', None) + pretrained = transf_decoder_cfg_dict.pop('pretrained', False) + self.transf_decoder = get_transformer( + library=library, + model_name=model_name, + pretrained=pretrained, + config_dict=transf_decoder_cfg_dict, + encoder=False, + pre_ln_final_layer_norm=transf_decoder_cfg_dict.get("pre_ln_final_layer_norm", False), + ) + + self.log_softmax = TokenClassifier( + hidden_size=self.transf_decoder.hidden_size, + num_classes=vocab_size, + activation=self.cfg.head.activation, + log_softmax=self.cfg.head.log_softmax, + dropout=self.cfg.head.dropout, + use_transformer_init=self.cfg.head.use_transformer_init, + ) + self.log_softmax.mlp.layer0.weight = self.transf_decoder.embedding.token_embedding.weight + std_init_range = 1 / self.transf_decoder.hidden_size ** 0.5 + self.transf_decoder.apply(lambda module: transformer_weights_init(module, std_init_range)) + self.log_softmax.apply(lambda module: transformer_weights_init(module, std_init_range)) + + # Beam Search decoding + self.beam_search = BeamSearchSequenceGenerator( + embedding=self.transf_decoder.embedding, + decoder=self.transf_decoder.decoder, + log_softmax=self.log_softmax, + max_sequence_length=self.transf_decoder.max_sequence_length, + beam_size=self.cfg.beam_search.beam_size, + bos=self.tokenizer.bos_id, + pad=self.tokenizer.pad_id, + eos=self.tokenizer.eos_id, + len_pen=self.cfg.beam_search.len_pen, + max_delta_length=self.cfg.beam_search.max_generation_delta, + ) + + # Define autoregressive CE loss + self.transf_loss = SmoothedCrossEntropyLoss( + pad_id=self.tokenizer.pad_id, label_smoothing=self.cfg.label_smoothing + ) + + if hasattr(self.cfg, 'spec_augment') and self.cfg.spec_augment is not None: + self.spec_augmentation = EncDecTransfModelBPE.from_config_dict(self.cfg.spec_augment) + else: + self.spec_augmentation = None + + self.val_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True) + + @torch.no_grad() + def translate( + self, + paths2audio_files: List[str], + batch_size: int = 4, + logprobs: bool = False, + return_hypotheses: bool = False, + ) -> List[str]: + hypotheses = self.transcribe(paths2audio_files, batch_size, logprobs, return_hypotheses) + return hypotheses + + @torch.no_grad() + def transcribe( + self, + paths2audio_files: List[str], + batch_size: int = 4, + logprobs: bool = False, + return_hypotheses: bool = False, + ) -> List[str]: + """ + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + Args: + paths2audio_files: (a list) of paths to audio files. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + logprobs: (bool) pass True to get log probabilities instead of transcripts. 
+ return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + if paths2audio_files is None or len(paths2audio_files) == 0: + return {} + + if return_hypotheses and logprobs: + raise ValueError( + "Either `return_hypotheses` or `logprobs` can be True at any given time." + "Returned hypotheses will contain the logprobs." + ) + + # We will store transcriptions here + hypotheses = [] + + # Model's mode and device + mode = self.training + device = next(self.parameters()).device + dither_value = self.preprocessor.featurizer.dither + pad_to_value = self.preprocessor.featurizer.pad_to + + try: + self.preprocessor.featurizer.dither = 0.0 + self.preprocessor.featurizer.pad_to = 0 + # Switch model to evaluation mode + self.eval() + # Freeze the encoder and decoder modules + self.encoder.freeze() + self.transf_decoder.freeze() + logging_level = logging.get_verbosity() + logging.set_verbosity(logging.WARNING) + # Work in tmp directory - will store manifest file there + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp: + for audio_file in paths2audio_files: + entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'} + fp.write(json.dumps(entry) + '\n') + + config = {'paths2audio_files': paths2audio_files, 'batch_size': batch_size, 'temp_dir': tmpdir} + + temporary_datalayer = self._setup_transcribe_dataloader(config) + for test_batch in tqdm(temporary_datalayer, desc="Transcribing"): + log_probs, encoded_len, enc_states, enc_mask = self.forward( + input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device) + ) + + beam_hypotheses = ( + self.beam_search( + encoder_hidden_states=enc_states, encoder_input_mask=enc_mask, return_beam_scores=False + ) + .detach() + .cpu() + .numpy() + ) + beam_hypotheses = [self.tokenizer.ids_to_text(hyp) for hyp in beam_hypotheses] + + if return_hypotheses: + # dump log probs per file + for idx in range(logits.shape[0]): + current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]] + + hypotheses += beam_hypotheses + + del test_batch, log_probs, encoded_len, enc_states, enc_mask + finally: + # set mode back to its original value + self.train(mode=mode) + self.preprocessor.featurizer.dither = dither_value + self.preprocessor.featurizer.pad_to = pad_to_value + if mode is True: + self.encoder.unfreeze() + self.transf_decoder.unfreeze() + logging.set_verbosity(logging_level) + + return hypotheses + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + + dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config( + config=config, + local_rank=self.local_rank, + global_rank=self.global_rank, + world_size=self.world_size, + tokenizer=self.tokenizer, + preprocessor_cfg=self.cfg.get("preprocessor", None), + ) + + if dataset is None: + return None + + shuffle = config['shuffle'] + if config.get('is_tarred', False): + shuffle = False + + if hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + else: + collate_fn = dataset.datasets[0].collate_fn + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def 
setup_training_data(self, train_data_config: Optional[DictConfig]): + + # create audio-only data loader + self._update_dataset_config(dataset_name='train', config=train_data_config) + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + # Need to set this because if using an IterableDataset, the length of the + # dataloader is the total number of samples rather than the number of batches, + # and this messes up the tqdm progress bar. So we set the number of steps manually + # (to the correct number) to fix this. + if 'is_tarred' in train_data_config and train_data_config['is_tarred']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, + # i.e. <= # training batches, and don't change it. Otherwise, adjust + # batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." + ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. 
+ Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + return { + "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "transcript": NeuralType(('B', 'T'), LabelsType(), optional=True), + "transcript_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "sample_id": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "transf_log_probs": NeuralType(('B', 'T', 'D'), LogprobsType()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "encoder_states": NeuralType(('B', 'T', 'D'), ChannelType()), + "encoder_mask": NeuralType(('B', 'T'), MaskType()), + } + + @typecheck() + def forward( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + transcript=None, + transcript_length=None, + ): + """ + Forward pass of the model. + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + processed_signal: Tensor that represents a batch of processed audio signals, + of shape (B, D, T) that has undergone processing via some DALI preprocessor. + processed_signal_length: Vector of length B, that contains the individual lengths of the + processed audio sequences. + Returns: + A tuple of 3 elements - + 1) The log probabilities tensor of shape [B, T, D]. + 2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B]. + 3) The greedy token predictions of the model of shape [B, T] (via argmax) + """ + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) == False: + raise ValueError( + f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." 
+ ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, length=input_signal_length + ) + + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + + enc_states = encoded.permute(0, 2, 1) + enc_states = self.adapter(enc_states) + enc_mask = lens_to_mask(encoded_len, enc_states.shape[1]).to(enc_states.dtype) + if self.use_transf_encoder: + enc_states = self.transf_encoder(encoder_states=enc_states, encoder_mask=enc_mask) + + transf_log_probs = None + if transcript is not None: + dec_mask = lens_to_mask(transcript_length, transcript.shape[1]).to(transcript.dtype) + dec_states = self.transf_decoder( + input_ids=transcript, decoder_mask=dec_mask, encoder_embeddings=enc_states, encoder_mask=enc_mask + ) + transf_log_probs = self.log_softmax(hidden_states=dec_states) + + return transf_log_probs, encoded_len, enc_states, enc_mask + + def compute_audio_loss(self, batch): + + if batch is None: + return 0 + + signal, signal_len, transcript, transcript_len = batch + input_ids, labels = transcript[:, :-1], transcript[:, 1:] + + transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( + input_signal=signal, + input_signal_length=signal_len, + transcript=input_ids, + transcript_length=transcript_len, + ) + + transf_loss = self.transf_loss(log_probs=transf_log_probs, labels=labels) + + return transf_loss + + # PTL-specific methods + def training_step(self, batch, batch_nb): + + audio_loss = self.compute_audio_loss(batch) + + tensorboard_logs = { + 'train_loss': audio_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + } + + return {'loss': audio_loss, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): + signal, signal_len, transcript, transcript_len = batch + input_ids, labels = transcript[:, :-1], transcript[:, 1:] + + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( + processed_signal=signal, + processed_signal_length=signal_len, + transcript=input_ids, + transcript_length=transcript_len, + ) + else: + transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( + input_signal=signal, + input_signal_length=signal_len, + transcript=input_ids, + transcript_length=transcript_len, + ) + + beam_hypotheses = self.beam_search( + encoder_hidden_states=enc_states, encoder_input_mask=enc_mask, return_beam_scores=False + ) + transf_loss = self.transf_loss(log_probs=transf_log_probs, labels=labels) + + ground_truths = [self.tokenizer.ids_to_text(sent) for sent in transcript.detach().cpu().tolist()] + translations = [self.tokenizer.ids_to_text(sent) for sent in beam_hypotheses.detach().cpu().tolist()] + + self.val_loss(loss=transf_loss, num_measurements=transf_log_probs.shape[0] * transf_log_probs.shape[1]) + + return {f'{eval_mode}_loss': transf_loss, 'translations': translations, 'ground_truths': ground_truths} + + def test_step(self, batch, batch_idx, dataloader_idx=0): + return self.validation_step(batch, batch_idx, dataloader_idx, eval_mode="test") + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, eval_mode: str = "val"): + """ + Called at the end of validation to aggregate outputs. 
+ :param outputs: list of individual outputs of each validation step. + """ + if not outputs: + return + + if isinstance(outputs[0], dict): + outputs = [outputs] + + for output in outputs: + eval_loss = getattr(self, 'val_loss').compute() + translations = list(itertools.chain(*[x['translations'] for x in output])) + ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output])) + + # Gather translations and ground truths from all workers + tr_and_gt = [None for _ in range(self.world_size)] + # we also need to drop pairs where ground truth is an empty string + if self.world_size > 1: + dist.all_gather_object( + tr_and_gt, [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] + ) + else: + tr_and_gt[0] = [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] + + if self.global_rank == 0: + _translations = [] + _ground_truths = [] + for rank in range(0, self.world_size): + _translations += [t for (t, g) in tr_and_gt[rank]] + _ground_truths += [g for (t, g) in tr_and_gt[rank]] + + sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="13a") + sb_score = sacre_bleu.score * self.world_size + + wer_scores, wer_words = 0, 0 + for h, r in zip(_translations, _ground_truths): + wer_words += len(r.split()) + wer_scores += editdistance.eval(h.split(), r.split()) + wer_score = 1.0 * wer_scores * self.world_size / wer_words + + else: + sb_score = 0.0 + wer_score = 0.0 + + self.log(f"{eval_mode}_loss", eval_loss, sync_dist=True) + self.log(f"{eval_mode}_sacreBLEU", sb_score, sync_dist=True) + self.log(f"{eval_mode}_WER", wer_score, sync_dist=True) + self.val_loss.reset() + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + return self.multi_validation_epoch_end(outputs, dataloader_idx, eval_mode="test") + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + Returns: + A pytorch DataLoader for the given audio file(s). 
+ """ + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + dl_config = { + 'manifest_filepath': os.path.join(config['temp_dir'], 'manifest.json'), + 'sample_rate': self.preprocessor._sample_rate, + 'batch_size': batch_size, + 'trim_silence': False, + 'shuffle': False, + 'num_workers': min(batch_size, os.cpu_count() - 1), + 'pin_memory': True, + } + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 8f429c25806d..66934928fc79 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -47,6 +47,7 @@ from nemo.core.classes.mixins import AccessMixin, adapter_mixins from nemo.core.classes.module import NeuralModule from nemo.core.neural_types import AcousticEncodedRepresentation, ChannelType, LengthsType, NeuralType, SpectrogramType +from nemo.utils import logging __all__ = ['ConformerEncoder'] @@ -778,6 +779,10 @@ def _calc_context_sizes( return att_context_size_all, att_context_size_all[0], att_context_probs, conv_context_size def set_default_att_context_size(self, att_context_size): + if att_context_size not in self.att_context_size_all: + logging.warning( + f"att_context_size={att_context_size} is not among the list of the supported look-aheads: {self.att_context_size_all}" + ) self.att_context_size = att_context_size def setup_streaming_params( diff --git a/nemo/collections/asr/parts/submodules/causal_convs.py b/nemo/collections/asr/parts/submodules/causal_convs.py index c6251690b1b1..32f08a8d2feb 100644 --- a/nemo/collections/asr/parts/submodules/causal_convs.py +++ b/nemo/collections/asr/parts/submodules/causal_convs.py @@ -130,13 +130,16 @@ def __init__( def update_cache(self, x, cache=None): if cache is None: new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) + next_cache = cache else: new_x = F.pad(x, pad=(0, self._right_padding)) new_x = torch.cat([cache, new_x], dim=-1) if self.cache_drop_size > 0: - x = x[:, :, : -self.cache_drop_size] - cache = torch.cat([cache[:, :, x.size(-1) :], x], dim=-1) - return new_x, cache + next_cache = new_x[:, :, : -self.cache_drop_size] + else: + next_cache = new_x + next_cache = next_cache[:, :, -cache.size(-1) :] + return new_x, next_cache def forward(self, x, cache=None): x, cache = self.update_cache(x, cache=cache) diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py index a64eded97208..1f29a511fc9c 100644 --- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py @@ -16,12 +16,13 @@ from typing import List, Optional import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from nemo.collections.asr.parts.utils import rnnt_utils -from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureMixin, ConfidenceMethodConfig +from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import HypothesisType, LengthsType, LogprobsType, NeuralType +from nemo.utils import logging def pack_hypotheses(hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor,) -> List[rnnt_utils.Hypothesis]: @@ -70,31 +71,32 @@ class 
GreedyCTCInfer(Typing, ConfidenceMeasureMixin): preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding. When set to true, the Hypothesis will contain the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats. - confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame + confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. + entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. 
@@ -128,7 +130,7 @@ def __init__( preserve_alignments: bool = False, compute_timestamps: bool = False, preserve_frame_confidence: bool = False, - confidence_method_cfg: Optional[DictConfig] = None, + confidence_measure_cfg: Optional[DictConfig] = None, ): super().__init__() @@ -138,8 +140,8 @@ def __init__( self.compute_timestamps = compute_timestamps | preserve_frame_confidence self.preserve_frame_confidence = preserve_frame_confidence - # set confidence calculation method - self._init_confidence_measure(confidence_method_cfg) + # set confidence calculation measure + self._init_confidence_measure(confidence_measure_cfg) @typecheck() def forward( @@ -251,4 +253,27 @@ class GreedyCTCInferConfig: preserve_alignments: bool = False compute_timestamps: bool = False preserve_frame_confidence: bool = False - confidence_method_cfg: Optional[ConfidenceMethodConfig] = None + confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = ConfidenceMeasureConfig() + confidence_method_cfg: str = "DEPRECATED" + + def __post_init__(self): + # OmegaConf.structured ensures that post_init check is always executed + self.confidence_measure_cfg = OmegaConf.structured( + self.confidence_measure_cfg + if isinstance(self.confidence_measure_cfg, ConfidenceMeasureConfig) + else ConfidenceMeasureConfig(**self.confidence_measure_cfg) + ) + if self.confidence_method_cfg != "DEPRECATED": + logging.warning( + "`confidence_method_cfg` is deprecated and will be removed in the future. " + "Please use `confidence_measure_cfg` instead." + ) + + # TODO (alaptev): delete the following two lines sometime in the future + logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.") + # OmegaConf.structured ensures that post_init check is always executed + self.confidence_measure_cfg = OmegaConf.structured( + self.confidence_method_cfg + if isinstance(self.confidence_method_cfg, ConfidenceMeasureConfig) + else ConfidenceMeasureConfig(**self.confidence_method_cfg) + ) diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py index a0253524419e..6a866a617f35 100644 --- a/nemo/collections/asr/parts/submodules/multi_head_attention.py +++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py @@ -377,12 +377,6 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None): scores += d_mask - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) - attn = self.dropout(attn) - # (batch, head, time, 2w + 1) - - out = self.sliding_chunks_matmul_pv(attn, v, w).reshape(n_batch, -1, self.h * self.d_k) - if self.global_tokens > 0: # create q, k, v for global attn @@ -426,21 +420,34 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None): is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero, ).transpose(1, 2) - global_key_attn = torch.softmax(global_key_attn, dim=-1).masked_fill(mask, 0.0) - global_key_attn = self.dropout(global_key_attn) + # concat to local_attn_probs + # (batch, time, head, max_num_global_attn_indices + 2*w) + scores = torch.cat((global_key_attn, scores), dim=-1) - # compute outputs for global attention from all tokens to global - # (batch, time, head x head_dim) - out_all_to_global = self._compute_out_all_to_global( - value=global_v, - attn_probs=global_key_attn, + # free memory + del global_key_attn + + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + p_attn = self.dropout(attn) + # (batch, head, time, 2w + 1) + + if self.global_tokens > 
0: + # compute sum of global and local attn + out = self._compute_attn_output_with_global_indices( + value=v, + attn_probs=p_attn, max_num_global_attn_indices=max_num_global_attn_indices, is_index_global_attn_nonzero=is_index_global_attn_nonzero, is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero, + w=w, ) + else: + # compute local attn only + out = self.sliding_chunks_matmul_pv(p_attn, v, w) + + out = out.reshape(n_batch, -1, self.h * self.d_k)[:, :T] - # compute outputs for global attention from global tokens to all - # (batch, max_num_global_attn_indices, head x head_dim) + if self.global_tokens > 0: out_global_to_all = self._compute_out_global_to_all( query=global_q, key=global_k, @@ -452,11 +459,11 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None): is_index_masked=mask, ) - out += out_all_to_global + # overwrite values with global attention + out[is_index_global_attn_nonzero] = out_global_to_all - out[is_index_global_attn_nonzero] += out_global_to_all + ret = self.linear_out(out) - ret = self.linear_out(out.reshape(n_batch, -1, self.h * self.d_k)[:, :T]) if cache is None: return ret else: @@ -544,16 +551,17 @@ def _compute_global_key_attn( return attn_probs_from_global_key - def _compute_out_all_to_global( + def _compute_attn_output_with_global_indices( self, value: torch.Tensor, attn_probs: torch.Tensor, max_num_global_attn_indices: int, is_index_global_attn_nonzero: tuple, is_local_index_global_attn_nonzero: tuple, + w: int, ) -> torch.Tensor: """ - Compute the attention output of all tokens attending to global. + Compute the attention output with global indices. Args: value (torch.Tensor): (batch, head, time, head_dim) The value vectors for global attention. @@ -561,7 +569,7 @@ def _compute_out_all_to_global( max_num_global_attn_indices (int): Maximum number of global attention indices in the batch. is_index_global_attn_nonzero (tuple): Indices of global attention (non-zero elements). is_local_index_global_attn_nonzero (tuple): Non-padding values within global attention indices. - + w (int): Local context size Returns: torch.Tensor: (batch, time, head x head_dim) The attention output of all tokens attending to global. 
""" @@ -573,12 +581,22 @@ def _compute_out_all_to_global( value_vectors_only_global = value.new_zeros(batch_size, max_num_global_attn_indices, self.h, self.d_k) value_vectors_only_global[is_local_index_global_attn_nonzero] = value[is_index_global_attn_nonzero] + # cut local attn probs to global only + attn_probs_only_global = attn_probs.narrow(-1, 0, max_num_global_attn_indices) # compute attn output only global - out_all_to_global = torch.matmul(attn_probs, value_vectors_only_global.transpose(1, 2)).transpose(1, 2) + attn_output_only_global = torch.matmul( + attn_probs_only_global.clone(), value_vectors_only_global.transpose(1, 2).clone() + ).transpose(1, 2) + + # reshape attn probs + attn_probs_without_global = attn_probs.narrow( + -1, max_num_global_attn_indices, attn_probs.size(-1) - max_num_global_attn_indices + ).contiguous() - out_all_to_global = out_all_to_global.reshape(batch_size, time, -1) + # compute attn output with global + attn_output_without_global = self.sliding_chunks_matmul_pv(attn_probs_without_global, value.transpose(1, 2), w) - return out_all_to_global + return attn_output_only_global + attn_output_without_global def _compute_out_global_to_all( self, diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index ac10e54bb249..dfa3ac27854b 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -31,11 +31,11 @@ import numpy as np import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from nemo.collections.asr.modules import rnnt_abstract from nemo.collections.asr.parts.utils import rnnt_utils -from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureMixin, ConfidenceMethodConfig +from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin from nemo.collections.common.parts.rnn import label_collate from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, ElementType, HypothesisType, LengthsType, NeuralType @@ -96,34 +96,32 @@ class _GreedyRNNTInfer(Typing, ConfidenceMeasureMixin): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. - confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame + confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - - 'entropy' for using normalized entropy of a log-likelihood vector. + - 'entropy' for using a normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. + entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). 
- Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. If the temperature α is provided, - the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. @@ -156,7 +154,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, - confidence_method_cfg: Optional[DictConfig] = None, + confidence_measure_cfg: Optional[DictConfig] = None, ): super().__init__() self.decoder = decoder_model @@ -168,8 +166,8 @@ def __init__( self.preserve_alignments = preserve_alignments self.preserve_frame_confidence = preserve_frame_confidence - # set confidence calculation method - self._init_confidence_measure(confidence_method_cfg) + # set confidence calculation measure + self._init_confidence_measure(confidence_measure_cfg) def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) @@ -265,31 +263,32 @@ class GreedyRNNTInfer(_GreedyRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. - confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame + confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - - 'entropy' for using normalized entropy of a log-likelihood vector. + - 'entropy' for using a normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. + entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. 
If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. @@ -306,7 +305,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, - confidence_method_cfg: Optional[DictConfig] = None, + confidence_measure_cfg: Optional[DictConfig] = None, ): super().__init__( decoder_model=decoder_model, @@ -315,7 +314,7 @@ def __init__( max_symbols_per_step=max_symbols_per_step, preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, - confidence_method_cfg=confidence_method_cfg, + confidence_measure_cfg=confidence_measure_cfg, ) @typecheck() @@ -503,31 +502,32 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. - confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame + confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - - 'entropy' for using normalized entropy of a log-likelihood vector. + - 'entropy' for using a normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. + entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. 
+ Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) entropy_norm: A mapping of the entropy value to the interval [0,1]. @@ -544,7 +544,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, - confidence_method_cfg: Optional[DictConfig] = None, + confidence_measure_cfg: Optional[DictConfig] = None, ): super().__init__( decoder_model=decoder_model, @@ -553,7 +553,7 @@ def __init__( max_symbols_per_step=max_symbols_per_step, preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, - confidence_method_cfg=confidence_method_cfg, + confidence_measure_cfg=confidence_measure_cfg, ) # Depending on availability of `blank_as_pad` support @@ -1478,29 +1478,34 @@ class GreedyMultiblankRNNTInfer(GreedyRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. - confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame + confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - - 'entropy' for using normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. + - 'entropy' for using a normalized entropy of a log-likelihood vector. + + entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. 
Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) + entropy_norm: A mapping of the entropy value to the interval [0,1]. Supported values: - 'lin' for using the linear mapping. @@ -1516,7 +1521,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, - confidence_method_cfg: Optional[DictConfig] = None, + confidence_measure_cfg: Optional[DictConfig] = None, ): super().__init__( decoder_model=decoder_model, @@ -1525,7 +1530,7 @@ def __init__( max_symbols_per_step=max_symbols_per_step, preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, - confidence_method_cfg=confidence_method_cfg, + confidence_measure_cfg=confidence_measure_cfg, ) self.big_blank_durations = big_blank_durations self._SOS = blank_index - len(big_blank_durations) @@ -1677,29 +1682,34 @@ class GreedyBatchedMultiblankRNNTInfer(GreedyBatchedRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. - confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame + confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - - 'entropy' for using normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. + - 'entropy' for using a normalized entropy of a log-likelihood vector. + + entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. 
More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) + entropy_norm: A mapping of the entropy value to the interval [0,1]. Supported values: - 'lin' for using the linear mapping. @@ -1715,7 +1725,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, - confidence_method_cfg: Optional[DictConfig] = None, + confidence_measure_cfg: Optional[DictConfig] = None, ): super().__init__( decoder_model=decoder_model, @@ -1724,7 +1734,7 @@ def __init__( max_symbols_per_step=max_symbols_per_step, preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, - confidence_method_cfg=confidence_method_cfg, + confidence_measure_cfg=confidence_measure_cfg, ) self.big_blank_durations = big_blank_durations @@ -2193,7 +2203,31 @@ class GreedyRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False - confidence_method_cfg: Optional[ConfidenceMethodConfig] = None + confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = ConfidenceMeasureConfig() + confidence_method_cfg: str = "DEPRECATED" + + def __post_init__(self): + # OmegaConf.structured ensures that post_init check is always executed + self.confidence_measure_cfg = OmegaConf.structured( + self.confidence_measure_cfg + if isinstance(self.confidence_measure_cfg, ConfidenceMeasureConfig) + else ConfidenceMeasureConfig(**self.confidence_measure_cfg) + ) + if self.confidence_method_cfg != "DEPRECATED": + logging.warning( + "`confidence_method_cfg` is deprecated and will be removed in the future. " + "Please use `confidence_measure_cfg` instead." 
+ ) + + # TODO (alaptev): delete the following two lines sometime in the future + logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.") + # OmegaConf.structured ensures that post_init check is always executed + self.confidence_measure_cfg = OmegaConf.structured( + self.confidence_method_cfg + if isinstance(self.confidence_method_cfg, ConfidenceMeasureConfig) + else ConfidenceMeasureConfig(**self.confidence_method_cfg) + ) + self.confidence_method_cfg = "DEPRECATED" @dataclass @@ -2201,7 +2235,31 @@ class GreedyBatchedRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False - confidence_method_cfg: Optional[ConfidenceMethodConfig] = None + confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = ConfidenceMeasureConfig() + confidence_method_cfg: str = "DEPRECATED" + + def __post_init__(self): + # OmegaConf.structured ensures that post_init check is always executed + self.confidence_measure_cfg = OmegaConf.structured( + self.confidence_measure_cfg + if isinstance(self.confidence_measure_cfg, ConfidenceMeasureConfig) + else ConfidenceMeasureConfig(**self.confidence_measure_cfg) + ) + if self.confidence_method_cfg != "DEPRECATED": + logging.warning( + "`confidence_method_cfg` is deprecated and will be removed in the future. " + "Please use `confidence_measure_cfg` instead." + ) + + # TODO (alaptev): delete the following two lines sometime in the future + logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.") + # OmegaConf.structured ensures that post_init check is always executed + self.confidence_measure_cfg = OmegaConf.structured( + self.confidence_method_cfg + if isinstance(self.confidence_method_cfg, ConfidenceMeasureConfig) + else ConfidenceMeasureConfig(**self.confidence_method_cfg) + ) + self.confidence_method_cfg = "DEPRECATED" class GreedyTDTInfer(_GreedyRNNTInfer): @@ -2230,29 +2288,34 @@ class GreedyTDTInfer(_GreedyRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. - confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame + confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - - 'entropy' for using normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. + - 'entropy' for using a normalized entropy of a log-likelihood vector. + + entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. 
+ Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) + entropy_norm: A mapping of the entropy value to the interval [0,1]. Supported values: - 'lin' for using the linear mapping. @@ -2268,7 +2331,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, - confidence_method_cfg: Optional[DictConfig] = None, + confidence_measure_cfg: Optional[DictConfig] = None, ): super().__init__( decoder_model=decoder_model, @@ -2277,7 +2340,7 @@ def __init__( max_symbols_per_step=max_symbols_per_step, preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, - confidence_method_cfg=confidence_method_cfg, + confidence_measure_cfg=confidence_measure_cfg, ) self.durations = durations @@ -2481,29 +2544,34 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer): The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. U is the number of target tokens for the current timestep Ti. - confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame + confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. - name: The method name (str). + + name: The measure name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - - 'entropy' for using normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. + - 'entropy' for using a normalized entropy of a log-likelihood vector. + + entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. Supported values: - - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). - Note that for this entropy, the temperature should comply the following inequality: - 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. 
- 'tsallis' for the Tsallis entropy with the Boltzmann constant one. Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/Tsallis_entropy - - 'renui' for the Rényi entropy. + - 'renyi' for the Rényi entropy. Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), where α is a parameter. When α == 1, it works like the Gibbs entropy. More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy - temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. - When the temperature equals one, scaling is not applied to 'max_prob', + + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) + entropy_norm: A mapping of the entropy value to the interval [0,1]. Supported values: - 'lin' for using the linear mapping. @@ -2519,7 +2587,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, - confidence_method_cfg: Optional[DictConfig] = None, + confidence_measure_cfg: Optional[DictConfig] = None, ): super().__init__( decoder_model=decoder_model, @@ -2528,7 +2596,7 @@ def __init__( max_symbols_per_step=max_symbols_per_step, preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, - confidence_method_cfg=confidence_method_cfg, + confidence_measure_cfg=confidence_measure_cfg, ) self.durations = durations diff --git a/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py new file mode 100644 index 000000000000..958195a4bb11 --- /dev/null +++ b/nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import copy +import os +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +import texterrors +import torch +from omegaconf import open_dict + +from nemo.collections.asr.models import ASRModel, EncDecRNNTModel +from nemo.collections.asr.parts.utils.confidence_metrics import ( + auc_nt, + auc_pr, + auc_roc, + auc_yc, + ece, + nce, + save_confidence_hist, + save_custom_confidence_curve, + save_nt_curve, + save_pr_curve, + save_roc_curve, +) +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis + + +def get_correct_marks(r: Union[List[int], List[str]], h: Union[List[int], List[str]]) -> List[bool]: + """Get correct marks by aligning the reference text with a hypothesis. + + This method considers only insertions and substitutions as incorrect marks. 
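For illustration only (not part of the patch): a minimal, self-contained sketch of the correct-mark idea that `get_correct_marks` above implements. It deliberately avoids the texterrors alignment and simply compares positions, so it only matches the real helper when the hypothesis contains no insertions or deletions; the function name is mine, not from the repository.

    # Hypothetical, simplified stand-in for get_correct_marks: assumes the reference
    # and hypothesis are already aligned position by position.
    def simple_correct_marks(ref, hyp):
        # True where the hypothesis token matches the reference token (substitutions -> False)
        return [r == h for r, h in zip(ref, hyp)]

    print(simple_correct_marks(["the", "cat", "sat"], ["the", "bat", "sat"]))
    # [True, False, True] -- the substituted "bat" is marked as incorrect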
+ """ + return [ + a == b + for a, b in zip(*(texterrors.align_texts([str(rr) for rr in r], [str(hh) for hh in h], False)[:-1])) + if b != "" + ] + + +def get_token_targets_with_confidence(hyp: Hypothesis) -> List[Tuple[str, float]]: + return [(y, c) for y, c in zip(hyp.y_sequence, hyp.token_confidence)] + + +def get_word_targets_with_confidence(hyp: Hypothesis) -> List[Tuple[str, float]]: + return [(y, c) for y, c in zip(hyp.words, hyp.word_confidence)] + + +def run_confidence_benchmark( + model: ASRModel, + target_level: str, + filepaths: List[str], + reference_texts: List[str], + batch_size: int = 8, + num_workers: int = 4, + plot_dir: Optional[Union[str, Path]] = None, + autocast: Optional = None, +): + """Run benchmark and plot histograms and curves, if plot_dir is provided. + + Returns: + Dictionary with benchmark results of the following scheme: + `level: (auc_roc, auc_pr, auc_nt, nce, ece, auc_yc, std_yc, max_yc)` with `level` being 'token' or 'word'. + """ + draw_plot = plot_dir is not None + if isinstance(plot_dir, str): + plot_dir = Path(plot_dir) + is_rnnt = isinstance(model, EncDecRNNTModel) + + # setup autocast if necessary + if autocast is None: + + @contextlib.contextmanager + def autocast(): + yield + + # transcribe audio + with autocast(): + with torch.no_grad(): + transcriptions = model.transcribe( + paths2audio_files=filepaths, batch_size=batch_size, return_hypotheses=True, num_workers=num_workers + ) + if is_rnnt: + transcriptions = transcriptions[0] + + levels = [] + if target_level != "word": + levels.append("token") + if target_level != "token": + levels.append("word") + results = {} + for level in levels: + if level == "token": + targets_with_confidence = [get_token_targets_with_confidence(tran) for tran in transcriptions] + correct_marks = [ + get_correct_marks(model.tokenizer.text_to_ids(r), model.tokenizer.text_to_ids(h.text)) + for r, h in zip(reference_texts, transcriptions) + ] + else: # "word" + targets_with_confidence = [get_word_targets_with_confidence(tran) for tran in transcriptions] + correct_marks = [get_correct_marks(r.split(), h.words) for r, h in zip(reference_texts, transcriptions)] + + y_true, y_score = np.array( + [[f, p[1]] for cm, twc in zip(correct_marks, targets_with_confidence) for f, p in zip(cm, twc)] + ).T + # output scheme: yc.mean(), yc.max(), yc.std() or yc.mean(), yc.max(), yc.std(), (thresholds, yc) + result_yc = auc_yc(y_true, y_score, return_std_maximum=True, return_curve=draw_plot) + # output scheme: ece or ece, (thresholds, ece_curve) + results_ece = ece(y_true, y_score, return_curve=draw_plot) + results[level] = [ + auc_roc(y_true, y_score), + auc_pr(y_true, y_score), + auc_nt(y_true, y_score), + nce(y_true, y_score), + results_ece if isinstance(results_ece, float) else results_ece[0], + ] + list(result_yc[:3]) + + if draw_plot: + os.makedirs(plot_dir, exist_ok=True) + + mask_correct = y_true == 1 + y_score_correct = y_score[mask_correct] + y_score_incorrect = y_score[~mask_correct] + # histogram of the correct distribution + save_confidence_hist(y_score_correct, plot_dir, level + "_" + "hist_correct") + # histogram of the incorrect distribution + save_confidence_hist(y_score_incorrect, plot_dir, level + "_" + "hist_incorrect") + # AUC-ROC curve + save_roc_curve(y_true, y_score, plot_dir, level + "_" + "roc") + # AUC-PR curve + save_pr_curve(y_true, y_score, plot_dir, level + "_" + "pr") + # AUC-NT curve + save_nt_curve(y_true, y_score, plot_dir, level + "_" + "nt") + # AUC-YC curve + yc_thresholds, yc_values = result_yc[-1] + 
save_custom_confidence_curve( + yc_thresholds, + yc_values, + plot_dir, + level + "_" + "yc", + "Threshold", + "True positive rate − False Positive Rate", + ) + # ECE curve + ece_thresholds, ece_values = results_ece[-1] + ece_values /= max(ece_values) + save_custom_confidence_curve( + ece_thresholds, ece_values, plot_dir, level + "_" + "ece", "Threshold", "|Accuracy − Confidence score|" + ) + + return results + + +def apply_confidence_parameters(decoding_cfg, hp): + """Apply parameters from a parameter grid to a decoding config. + + Returns: + Updated decoding config. + """ + new_decoding_cfg = copy.deepcopy(decoding_cfg) + confidence_cfg_fields = ("aggregation", "exclude_blank") + confidence_measure_cfg_fields = ("name", "alpha", "entropy_type", "entropy_norm") + with open_dict(new_decoding_cfg): + for p, v in hp.items(): + if p in confidence_cfg_fields: + new_decoding_cfg.confidence_cfg[p] = v + elif p in confidence_measure_cfg_fields: + new_decoding_cfg.confidence_cfg.measure_cfg[p] = v + return new_decoding_cfg diff --git a/nemo/collections/asr/parts/utils/asr_confidence_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_utils.py index 1387f6940b38..29c49529a509 100644 --- a/nemo/collections/asr/parts/utils/asr_confidence_utils.py +++ b/nemo/collections/asr/parts/utils/asr_confidence_utils.py @@ -18,46 +18,197 @@ from functools import partial from typing import List, Optional +import torch from omegaconf import DictConfig, OmegaConf from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.utils import logging + + +class ConfidenceMeasureConstants: + NAMES = ("max_prob", "entropy") + ENTROPY_TYPES = ("gibbs", "tsallis", "renyi") + ENTROPY_NORMS = ("lin", "exp") + + @classmethod + def print(cls): + return ( + cls.__name__ + + ": " + + str({"NAMES": cls.NAMES, "ENTROPY_TYPES": cls.ENTROPY_TYPES, "ENTROPY_NORMS": cls.ENTROPY_NORMS}) + ) + + +class ConfidenceConstants: + AGGREGATIONS = ("mean", "min", "max", "prod") + + @classmethod + def print(cls): + return cls.__name__ + ": " + str({"AGGREGATIONS": cls.AGGREGATIONS}) @dataclass -class ConfidenceMethodConfig: +class ConfidenceMeasureConfig: + """A Config which contains the measure name and settings to compute per-frame confidence scores. + + Args: + name: The measure name (str). + Supported values: + - 'max_prob' for using the maximum token probability as a confidence. + - 'entropy' for using a normalized entropy of a log-likelihood vector. + + entropy_type: Which type of entropy to use (str). + Used if confidence_measure_cfg.name is set to `entropy`. + Supported values: + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, + the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. + - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. + Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), + where α is a parameter. When α == 1, it works like the Gibbs entropy. + More: https://en.wikipedia.org/wiki/Tsallis_entropy + - 'renyi' for the Rényi entropy. + Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), + where α is a parameter. When α == 1, it works like the Gibbs entropy. + More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy + + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. 
+ When the alpha equals one, scaling is not applied to 'max_prob', + and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) + + entropy_norm: A mapping of the entropy value to the interval [0,1]. + Supported values: + - 'lin' for using the linear mapping. + - 'exp' for using exponential mapping with linear shift. + """ + name: str = "entropy" entropy_type: str = "tsallis" - temperature: float = 0.33 + alpha: float = 0.33 entropy_norm: str = "exp" + temperature: str = "DEPRECATED" def __post_init__(self): - if self.name not in ("max_prob", "entropy"): - raise ValueError(f"`name` has to be one of the following: `max_prob`, `entropy`. Provided: {self.name}") - if self.entropy_type not in ("gibbs", "tsallis", "renui"): + if self.temperature != "DEPRECATED": + logging.warning( + "`temperature` is deprecated and will be removed in the future. Please use `alpha` instead." + ) + + # TODO (alaptev): delete the following two lines sometime in the future + logging.warning("Re-writing `alpha` with the value of `temperature`.") + # self.temperature has type str + self.alpha = float(self.temperature) + self.temperature = "DEPRECATED" + if self.name not in ConfidenceMeasureConstants.NAMES: raise ValueError( - f"`entropy_type` has to be one of the following: `gibbs`, `tsallis`, `renui`. Provided: {self.entropy_type}" + f"`name` must be one of the following: " + f"{'`' + '`, `'.join(ConfidenceMeasureConstants.NAMES) + '`'}. Provided: `{self.name}`" ) - if self.temperature <= 0.0: - raise ValueError(f"`temperature` has to be > 0. Provided: {self.temperature}") - if self.entropy_norm not in ("lin", "exp"): + if self.entropy_type not in ConfidenceMeasureConstants.ENTROPY_TYPES: raise ValueError( - f"`entropy_norm` has to be one of the following: `lin`, `exp`. Provided: {self.entropy_norm}" + f"`entropy_type` must be one of the following: " + f"{'`' + '`, `'.join(ConfidenceMeasureConstants.ENTROPY_TYPES) + '`'}. Provided: `{self.entropy_type}`" + ) + if self.alpha <= 0.0: + raise ValueError(f"`alpha` must be > 0. Provided: {self.alpha}") + if self.entropy_norm not in ConfidenceMeasureConstants.ENTROPY_NORMS: + raise ValueError( + f"`entropy_norm` must be one of the following: " + f"{'`' + '`, `'.join(ConfidenceMeasureConstants.ENTROPY_NORMS) + '`'}. Provided: `{self.entropy_norm}`" ) @dataclass class ConfidenceConfig: + """A config which contains the following key-value pairs related to confidence scores. + + Args: + preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores + generated during decoding. When set to true, the Hypothesis will contain + the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats. + preserve_token_confidence: Bool flag which preserves the history of per-token confidence scores + generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain + the non-null value for `token_confidence` in it. Here, `token_confidence` is a List of floats. + + The length of the list corresponds to the number of recognized tokens. + preserve_word_confidence: Bool flag which preserves the history of per-word confidence scores + generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain + the non-null value for `word_confidence` in it. Here, `word_confidence` is a List of floats. + + The length of the list corresponds to the number of recognized words. 
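For illustration only (not part of the patch): a minimal sketch of how a measure config built from the `ConfidenceMeasureConfig` dataclass above might look, assuming omegaconf is installed and the import path matches this diff. The legacy `temperature` field is shown solely to illustrate the remapping performed in `__post_init__`.

    from omegaconf import OmegaConf

    # Import path assumed from this diff; adjust if the module layout differs.
    from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig

    # New-style config: `alpha` replaces the deprecated `temperature`.
    measure_cfg = OmegaConf.structured(
        ConfidenceMeasureConfig(name="entropy", entropy_type="tsallis", alpha=0.33, entropy_norm="exp")
    )

    # Legacy-style config: `temperature` still works, but it logs a deprecation warning
    # and __post_init__ copies its value into `alpha`.
    legacy_cfg = OmegaConf.structured(ConfidenceMeasureConfig(temperature="0.5"))
    assert legacy_cfg.alpha == 0.5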
+ exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded + from the `token_confidence`. + aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. + Valid options are `mean`, `min`, `max`, `prod`. + measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame + confidence scores. + + name: The measure name (str). + Supported values: + - 'max_prob' for using the maximum token probability as a confidence. + - 'entropy' for using a normalized entropy of a log-likelihood vector. + + entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. + Supported values: + - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, + the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). + Note that for this entropy, the alpha should comply the following inequality: + (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1) + where V is the model vocabulary size. + - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. + Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), + where α is a parameter. When α == 1, it works like the Gibbs entropy. + More: https://en.wikipedia.org/wiki/Tsallis_entropy + - 'renyi' for the Rényi entropy. + Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), + where α is a parameter. When α == 1, it works like the Gibbs entropy. + More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy + + alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the alpha equals one, scaling is not applied to 'max_prob', + and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) + + entropy_norm: A mapping of the entropy value to the interval [0,1]. + Supported values: + - 'lin' for using the linear mapping. + - 'exp' for using exponential mapping with linear shift. + """ + preserve_frame_confidence: bool = False preserve_token_confidence: bool = False preserve_word_confidence: bool = False exclude_blank: bool = True aggregation: str = "min" - method_cfg: ConfidenceMethodConfig = ConfidenceMethodConfig() + measure_cfg: ConfidenceMeasureConfig = ConfidenceMeasureConfig() + method_cfg: str = "DEPRECATED" def __post_init__(self): - if self.aggregation not in ("mean", "min", "max", "prod"): + # OmegaConf.structured ensures that post_init check is always executed + self.measure_cfg = OmegaConf.structured( + self.measure_cfg + if isinstance(self.measure_cfg, ConfidenceMeasureConfig) + else ConfidenceMeasureConfig(**self.measure_cfg) + ) + if self.method_cfg != "DEPRECATED": + logging.warning( + "`method_cfg` is deprecated and will be removed in the future. Please use `measure_cfg` instead." + ) + + # TODO (alaptev): delete the following two lines sometime in the future + logging.warning("Re-writing `measure_cfg` with the value of `method_cfg`.") + # OmegaConf.structured ensures that post_init check is always executed + self.measure_cfg = OmegaConf.structured( + self.method_cfg + if isinstance(self.method_cfg, ConfidenceMeasureConfig) + else ConfidenceMeasureConfig(**self.method_cfg) + ) + self.method_cfg = "DEPRECATED" + if self.aggregation not in ConfidenceConstants.AGGREGATIONS: raise ValueError( - f"`aggregation` has to be one of the following: `mean`, `min`, `max`, `prod`. 
Provided: {self.aggregation}" + f"`aggregation` has to be one of the following: " + f"{'`' + '`, `'.join(ConfidenceConstants.AGGREGATIONS) + '`'}. Provided: `{self.aggregation}`" ) @@ -70,32 +221,32 @@ def get_confidence_measure_bank(): entropy_gibbs_exp: Gibbs entropy with exponential normalization entropy_tsallis_lin: Tsallis entropy with linear normalization entropy_tsallis_exp: Tsallis entropy with exponential normalization - entropy_renui_lin: Rényi entropy with linear normalization - entropy_renui_exp: Rényi entropy with exponential normalization + entropy_renyi_lin: Rényi entropy with linear normalization + entropy_renyi_exp: Rényi entropy with exponential normalization Returns: dictionary with lambda functions. """ # helper functions - # Gibbs entropy is implemented without temperature + # Gibbs entropy is implemented without alpha neg_entropy_gibbs = lambda x: (x.exp() * x).sum(-1) - neg_entropy_temperature = lambda x, t: (x * t).exp().sum(-1) - neg_entropy_temperature_gibbs = lambda x, t: ((x * t).exp() * x).sum(-1) + neg_entropy_alpha = lambda x, t: (x * t).exp().sum(-1) + neg_entropy_alpha_gibbs = lambda x, t: ((x * t).exp() * x).sum(-1) # too big for a lambda def entropy_tsallis_exp(x, v, t): exp_neg_max_ent = math.exp((1 - math.pow(v, 1 - t)) / (1 - t)) - return (((1 - neg_entropy_temperature(x, t)) / (1 - t)).exp() - exp_neg_max_ent) / (1 - exp_neg_max_ent) + return (((1 - neg_entropy_alpha(x, t)) / (1 - t)).exp() - exp_neg_max_ent) / (1 - exp_neg_max_ent) def entropy_gibbs_exp(x, v, t): exp_neg_max_ent = math.pow(v, -t * math.pow(v, 1 - t)) - return ((neg_entropy_temperature_gibbs(x, t) * t).exp() - exp_neg_max_ent) / (1 - exp_neg_max_ent) + return ((neg_entropy_alpha_gibbs(x, t) * t).exp() - exp_neg_max_ent) / (1 - exp_neg_max_ent) # use Gibbs entropies for Tsallis and Rényi with t == 1.0 entropy_gibbs_lin_baseline = lambda x, v: 1 + neg_entropy_gibbs(x) / math.log(v) entropy_gibbs_exp_baseline = lambda x, v: (neg_entropy_gibbs(x).exp() * v - 1) / (v - 1) # fill the measure bank confidence_measure_bank = {} - # Maximum probability measure is implemented without temperature + # Maximum probability measure is implemented without alpha confidence_measure_bank["max_prob"] = ( lambda x, v, t: (x.max(dim=-1)[0].exp() * v - 1) / (v - 1) if t == 1.0 @@ -104,7 +255,7 @@ def entropy_gibbs_exp(x, v, t): confidence_measure_bank["entropy_gibbs_lin"] = ( lambda x, v, t: entropy_gibbs_lin_baseline(x, v) if t == 1.0 - else 1 + neg_entropy_temperature_gibbs(x, t) / math.log(v) / math.pow(v, 1 - t) + else 1 + neg_entropy_alpha_gibbs(x, t) / math.log(v) / math.pow(v, 1 - t) ) confidence_measure_bank["entropy_gibbs_exp"] = ( lambda x, v, t: entropy_gibbs_exp_baseline(x, v) if t == 1.0 else entropy_gibbs_exp(x, v, t) @@ -112,20 +263,20 @@ def entropy_gibbs_exp(x, v, t): confidence_measure_bank["entropy_tsallis_lin"] = ( lambda x, v, t: entropy_gibbs_lin_baseline(x, v) if t == 1.0 - else 1 + (1 - neg_entropy_temperature(x, t)) / (math.pow(v, 1 - t) - 1) + else 1 + (1 - neg_entropy_alpha(x, t)) / (math.pow(v, 1 - t) - 1) ) confidence_measure_bank["entropy_tsallis_exp"] = ( lambda x, v, t: entropy_gibbs_exp_baseline(x, v) if t == 1.0 else entropy_tsallis_exp(x, v, t) ) - confidence_measure_bank["entropy_renui_lin"] = ( + confidence_measure_bank["entropy_renyi_lin"] = ( lambda x, v, t: entropy_gibbs_lin_baseline(x, v) if t == 1.0 - else 1 + neg_entropy_temperature(x, t).log2() / (t - 1) / math.log(v, 2) + else 1 + neg_entropy_alpha(x, t).log2() / (t - 1) / math.log(v, 2) ) -
confidence_measure_bank["entropy_renui_exp"] = ( + confidence_measure_bank["entropy_renyi_exp"] = ( lambda x, v, t: entropy_gibbs_exp_baseline(x, v) if t == 1.0 - else (neg_entropy_temperature(x, t).pow(1 / (t - 1)) * v - 1) / (v - 1) + else (neg_entropy_alpha(x, t).pow(1 / (t - 1)) * v - 1) / (v - 1) ) return confidence_measure_bank @@ -160,48 +311,55 @@ class ConfidenceMeasureMixin(ABC): It initializes per-frame confidence measure. """ - def _init_confidence_measure(self, confidence_method_cfg: Optional[DictConfig] = None): + def _init_confidence_measure(self, confidence_measure_cfg: Optional[DictConfig] = None): """Initialize per-frame confidence measure from config. """ - if confidence_method_cfg is None: - confidence_method_cfg = OmegaConf.structured(ConfidenceMethodConfig()) + # OmegaConf.structured ensures that post_init check is always executed + confidence_measure_cfg = OmegaConf.structured( + ConfidenceMeasureConfig() + if confidence_measure_cfg is None + else ConfidenceMeasureConfig(**confidence_measure_cfg) + ) - # set confidence calculation method + # set confidence calculation measure # we suppose that self.blank_id == len(vocabulary) self.num_tokens = (self.blank_id if hasattr(self, "blank_id") else self._blank_index) + 1 - self.temperature = confidence_method_cfg.temperature + self.alpha = confidence_measure_cfg.alpha # init confidence measure bank self.confidence_measure_bank = get_confidence_measure_bank() - method = None + measure = None # construct measure_name measure_name = "" - if confidence_method_cfg.name == "max_prob": + if confidence_measure_cfg.name == "max_prob": measure_name = "max_prob" - elif confidence_method_cfg.name == "entropy": + elif confidence_measure_cfg.name == "entropy": measure_name = '_'.join( - [confidence_method_cfg.name, confidence_method_cfg.entropy_type, confidence_method_cfg.entropy_norm] + [confidence_measure_cfg.name, confidence_measure_cfg.entropy_type, confidence_measure_cfg.entropy_norm] ) else: - raise ValueError(f"Unsupported `confidence_method_cfg.name`: `{confidence_method_cfg.name}`") + raise ValueError(f"Unsupported `confidence_measure_cfg.name`: `{confidence_measure_cfg.name}`") if measure_name not in self.confidence_measure_bank: raise ValueError(f"Unsupported measure setup: `{measure_name}`") - method = partial(self.confidence_measure_bank[measure_name], v=self.num_tokens, t=self.temperature) - self._get_confidence = lambda x: method(x).tolist() + measure = partial(self.confidence_measure_bank[measure_name], v=self.num_tokens, t=self.alpha) + self._get_confidence = lambda x: measure(torch.nan_to_num(x)).tolist() class ConfidenceMixin(ABC): """Confidence Mixin class. - It initializes per-frame confidence measure. + It is responsible for confidence estimation method initialization and high-level confidence score calculation. """ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): """Initialize confidence-related fields and confidence aggregation function from config. 
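For illustration only (not part of the patch): a small, self-contained numeric sketch, with names of my own choosing, that reproduces the `entropy_tsallis_exp` formula from the measure bank above on a log-softmax vector, and evaluates the admissible alpha range quoted for the Gibbs entropy in the docstrings above.

    import math

    import torch

    def tsallis_exp_confidence(x: torch.Tensor, v: int, t: float) -> torch.Tensor:
        # x: log-softmax vector, v: vocabulary size, t: alpha (t != 1)
        exp_neg_max_ent = math.exp((1 - math.pow(v, 1 - t)) / (1 - t))
        neg_ent = (x * t).exp().sum(-1)  # sum_i p_i^alpha
        return (((1 - neg_ent) / (1 - t)).exp() - exp_neg_max_ent) / (1 - exp_neg_max_ent)

    v, t = 128, 0.33
    uniform = torch.full((v,), 1.0 / v).log()                    # maximally uncertain frame
    peaked = torch.log_softmax(torch.eye(v)[0] * 100.0, dim=-1)  # near one-hot frame
    print(float(tsallis_exp_confidence(uniform, v, t)))          # ~0.0
    print(float(tsallis_exp_confidence(peaked, v, t)))           # ~1.0

    # Admissible alpha range for the Gibbs entropy measure, per the docstrings above:
    lo = (math.log(v) + 2 - math.sqrt(math.log(v) ** 2 + 4)) / (2 * math.log(v))
    hi = (1 + math.log(v - 1)) / math.log(v - 1)
    print(lo, hi)  # roughly 0.17 and 1.21 for v == 128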
""" - if confidence_cfg is None: - confidence_cfg = OmegaConf.structured(ConfidenceConfig()) + # OmegaConf.structured ensures that post_init check is always executed + confidence_cfg = OmegaConf.structured( + ConfidenceConfig() if confidence_cfg is None else ConfidenceConfig(**confidence_cfg) + ) + self.confidence_measure_cfg = confidence_cfg.measure_cfg # extract the config self.preserve_word_confidence = confidence_cfg.get('preserve_word_confidence', False) @@ -216,7 +374,6 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): ) self.exclude_blank_from_confidence = confidence_cfg.get('exclude_blank', True) self.word_confidence_aggregation = confidence_cfg.get('aggregation', "min") - self.confidence_method_cfg = confidence_cfg.get('method_cfg', None) # define aggregation functions self.confidence_aggregation_bank = get_confidence_aggregation_bank() @@ -226,7 +383,13 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): if self.preserve_frame_confidence is False: if self.cfg.strategy in ['greedy', 'greedy_batch']: self.preserve_frame_confidence = self.cfg.greedy.get('preserve_frame_confidence', False) - self.confidence_method_cfg = self.cfg.greedy.get('confidence_method_cfg', None) + # OmegaConf.structured ensures that post_init check is always executed + confidence_measure_cfg = OmegaConf.structured(self.cfg.greedy).get('confidence_measure_cfg', None) + self.confidence_measure_cfg = ( + OmegaConf.structured(ConfidenceMeasureConfig()) + if confidence_measure_cfg is None + else OmegaConf.structured(ConfidenceMeasureConfig(**confidence_measure_cfg)) + ) @abstractmethod def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothesis]: diff --git a/nemo/collections/asr/parts/utils/confidence_metrics.py b/nemo/collections/asr/parts/utils/confidence_metrics.py index 28aa49959041..7d793c9df607 100644 --- a/nemo/collections/asr/parts/utils/confidence_metrics.py +++ b/nemo/collections/asr/parts/utils/confidence_metrics.py @@ -13,47 +13,94 @@ # limitations under the License. import math +import os +from pathlib import Path +from typing import List, Optional, Tuple, Union +import matplotlib.pyplot as plt import numpy as np -from sklearn.metrics import average_precision_score, log_loss, roc_auc_score +from sklearn.metrics import ( + PrecisionRecallDisplay, + RocCurveDisplay, + average_precision_score, + log_loss, + precision_recall_curve, + roc_auc_score, + roc_curve, +) -def auc_roc(y_true, y_score): +def auc_roc(y_true: Union[List[int], np.ndarray], y_score: Union[List[float], np.ndarray]) -> float: """Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores. + + Note: If only one class is present in y_true, 0.5 is returned. """ + y_true = np.array(y_true) + y_score = np.array(y_score) + assert len(y_true) == len(y_score) + assert np.all(y_true >= 0) and np.all(y_true <= 1) + if np.all(y_true == 0) or np.all(y_true == 1): + return 0.5 return roc_auc_score(y_true, y_score) -def auc_pr(y_true, y_score): +def auc_pr(y_true: Union[List[int], np.ndarray], y_score: Union[List[float], np.ndarray]) -> float: """Compute Area Under the Precision-Recall Curve (PR AUC) from prediction scores. + + Note: If only regatives are present in y_true, 0.0 is returned. 
""" + y_true = np.array(y_true) + y_score = np.array(y_score) + assert len(y_true) == len(y_score) + assert np.all(y_true >= 0) and np.all(y_true <= 1) + if np.all(y_true == 0): + return 0.0 return average_precision_score(y_true, y_score) -def auc_nt(y_true, y_score): +def auc_nt(y_true: Union[List[int], np.ndarray], y_score: Union[List[float], np.ndarray]) -> float: """Compute Area Under the Negative Predictive Value vs. True Negative Rate Curve (NT AUC) from prediction scores. This metric can be thought of as a PR AUC in which errors are treated as positives. + + Note: If only positives are present in y_true, 0.0 is returned. """ y_true = np.array(y_true) y_score = np.array(y_score) + assert len(y_true) == len(y_score) + assert np.all(y_true >= 0) and np.all(y_true <= 1) + if np.all(y_true == 1): + return 0.0 return average_precision_score(1 - y_true, 1 - y_score) -def nce(y_true, y_score): +def nce(y_true: Union[List[int], np.ndarray], y_score: Union[List[float], np.ndarray]) -> float: """Compute Normalized Cross Entropy (NCE) from prediction scores. Also known as the Normalized Mutual Information. NCE measures how close the correct prediction scores are to one and the incorrect prediction scores are to zero. Negative NCE values indicate that the classifier performs worse than the setting all prediction scores as the proportion of correct predictions. + + Note: If only one class is present in y_true, 0.5 is returned. """ - p = sum(y_true) / len(y_true) + y_true = np.array(y_true) + y_score = np.array(y_score) + assert len(y_true) == len(y_score) + assert np.all(y_true >= 0) and np.all(y_true <= 1) + if np.all(y_true == 0) or np.all(y_true == 1): + return -math.inf + p = y_true.mean() eps = 1e-15 Hp = -(math.log(p + eps) * p + math.log(1 - p + eps) * (1 - p)) return (Hp - log_loss(y_true, y_score)) / Hp -def ece(y_true, y_score, n_bins=100): +def ece( + y_true: Union[List[int], np.ndarray], + y_score: Union[List[float], np.ndarray], + n_bins: int = 100, + return_curve: bool = False, +) -> Union[float, Tuple[float, Tuple[List[int], List[float]]]]: """Compute Expected Calibration Error (ECE) from prediction scores. ECE measures how close the correct prediction scores are to one and the incorrect prediction scores are to zero. 
@@ -61,54 +108,159 @@ def ece(y_true, y_score, n_bins=100): """ y_true = np.array(y_true) y_score = np.array(y_score) + assert len(y_true) == len(y_score) + assert np.all(y_true >= 0) and np.all(y_true <= 1) py = np.array([1 - y_score, y_score]).T acc, conf = np.zeros(n_bins), np.zeros(n_bins) Bm = np.zeros(n_bins) + ece_curve = [] + thresholds = [] for m in range(n_bins): a, b = m / n_bins, (m + 1) / n_bins threshold = (a + b) / 2 + thresholds.append(threshold) py_index = (py.T[1] >= threshold).astype(int) py_value = py[np.arange(len(py_index)), py_index] bin_range = ((py_value > a) & (py_value <= b)).nonzero()[0] Bm[m] = len(bin_range) if Bm[m] > 0: - acc[m] = (py_index[bin_range] == y_true[bin_range]).sum() - conf[m] = py_value[bin_range].sum() - if Bm[m] != 0: - acc[m] /= Bm[m] - conf[m] /= Bm[m] - ece = 0 - for m in range(n_bins): - ece += Bm[m] * np.abs((acc[m] - conf[m])) - return ece / sum(Bm) + acc[m] = (py_index[bin_range] == y_true[bin_range]).sum() / Bm[m] + conf[m] = py_value[bin_range].sum() / Bm[m] + ece_curve.append(Bm[m] * np.abs(acc[m] - conf[m])) + ece = sum(ece_curve) / sum(Bm) + if return_curve: + return ece, (thresholds, ece_curve) + else: + return ece -def auc_yc(y_true, y_score, return_std_maximum=False, return_curve=False, n_bins=100): +def auc_yc( + y_true: Union[List[int], np.ndarray], + y_score: Union[List[float], np.ndarray], + n_bins: int = 100, + return_std_maximum: bool = False, + return_curve: bool = False, +) -> Union[ + float, + Tuple[float, Tuple[List[int], List[float]]], + Tuple[float, float, float], + Tuple[float, float, float, Tuple[List[int], List[float]]], +]: """Compute Area Under the Youden's Curve (YC AUC) from prediction scores. YC AUC represents the rate of the effective threshold range. If return_std_maximum is set to True, std and maximum values of the Youden's Curve are returned with the AUC. + + Note: If only one class is present in y_true, zeroes are returned for every entity. 
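For illustration only (not part of the patch): a short usage sketch for the confidence metrics above, assuming they are importable from nemo.collections.asr.parts.utils.confidence_metrics as this diff lays them out, and assuming the yc.mean(), yc.std(), yc.max() return ordering this patch settles on. Here y_true marks whether a recognized unit was correct and y_score is its confidence.

    import numpy as np

    # Import path assumed from this diff.
    from nemo.collections.asr.parts.utils.confidence_metrics import auc_nt, auc_pr, auc_roc, auc_yc, ece, nce

    y_true = np.array([1, 1, 0, 1, 0, 1])               # 1 = correct unit, 0 = recognition error
    y_score = np.array([0.9, 0.8, 0.3, 0.7, 0.4, 0.6])  # per-unit confidence scores

    print(auc_roc(y_true, y_score))  # separability of correct vs. incorrect units
    print(auc_pr(y_true, y_score))   # correct units treated as the positive class
    print(auc_nt(y_true, y_score))   # errors treated as the positive class
    print(nce(y_true, y_score))      # negative values mean worse than a constant predictor

    # Curve-returning variants; either curve can be written out with
    # save_custom_confidence_curve from the same module.
    ece_value, (ece_thresholds, ece_curve) = ece(y_true, y_score, n_bins=10, return_curve=True)
    yc_mean, yc_std, yc_max, (yc_thresholds, yc_curve) = auc_yc(
        y_true, y_score, n_bins=10, return_std_maximum=True, return_curve=True
    )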
""" y_true = np.array(y_true) y_score = np.array(y_score) + thresholds = np.linspace(0, 1, n_bins + 1) + assert len(y_true) == len(y_score) + assert np.all(y_true >= 0) and np.all(y_true <= 1) + if np.all(y_true == 0) or np.all(y_true == 1): + if return_std_maximum and return_curve: + return 0.0, 0.0, 0.0, (thresholds, np.zeros(len(thresholds))) + elif return_std_maximum: + return 0.0, 0.0, 0.0 + elif return_curve: + return 0.0, (thresholds, np.zeros(len(thresholds))) + else: + return 0.0 mask_correct = y_true == 1 - count_correct = len(mask_correct.nonzero()[0]) - count_incorrect = len(y_true) - count_correct + count_correct = max(len(mask_correct.nonzero()[0]), 1) + count_incorrect = max(len(y_true) - count_correct, 1) y_score_correct = y_score[mask_correct] y_score_incorrect = y_score[~mask_correct] yc = [] - thresholds = [i / n_bins for i in range(0, n_bins + 1)] for threshold in thresholds: - tnr = len((np.array(y_score_incorrect) < threshold).nonzero()[0]) / count_incorrect - fnr = len((np.array(y_score_correct) < threshold).nonzero()[0]) / count_correct - yc.append(tnr - fnr) + tnr = len((y_score_incorrect < threshold).nonzero()[0]) / count_incorrect + fnr = len((y_score_correct < threshold).nonzero()[0]) / count_correct + yc.append(abs(tnr - fnr)) yc = np.array(yc) if return_std_maximum and return_curve: - return yc.mean(), yc.max(), yc.std(), (thresholds, yc) + return yc.mean(), yc.std(), yc.max(), (thresholds, yc) elif return_std_maximum: - return yc.mean(), yc.max(), yc.std() + return yc.mean(), yc.std(), yc.max() elif return_curve: return yc.mean(), (thresholds, yc) else: return yc.mean() + + +def save_confidence_hist(y_score: Union[List[float], np.ndarray], plot_dir: Union[str, Path], name: str = "hist"): + os.makedirs(plot_dir, exist_ok=True) + plt.hist(np.array(y_score), 50, range=(0, 1)) + plt.title(name) + plt.xlabel("Confidence score") + plt.ylabel("Count") + plt.savefig(Path(plot_dir) / Path(name + ".png"), dpi=300) + plt.clf() + + +def save_roc_curve( + y_true: Union[List[int], np.ndarray], + y_score: Union[List[float], np.ndarray], + plot_dir: Union[str, Path], + name: str = "roc", +): + assert len(y_true) == len(y_score) + os.makedirs(plot_dir, exist_ok=True) + fpr, tpr, _ = roc_curve(1 - np.array(y_true), 1 - np.array(y_score)) + RocCurveDisplay(fpr=fpr, tpr=tpr).plot() + plt.title(name) + plt.savefig(Path(plot_dir) / Path(name + ".png"), dpi=300) + plt.clf() + + +def save_pr_curve( + y_true: Union[List[int], np.ndarray], + y_score: Union[List[float], np.ndarray], + plot_dir: Union[str, Path], + name: str = "pr", +): + assert len(y_true) == len(y_score) + os.makedirs(plot_dir, exist_ok=True) + precision, recall, _ = precision_recall_curve(np.array(y_true), np.array(y_score)) + PrecisionRecallDisplay(precision=precision, recall=recall).plot() + plt.title(name) + plt.savefig(Path(plot_dir) / Path(name + ".png"), dpi=300) + plt.clf() + + +def save_nt_curve( + y_true: Union[List[int], np.ndarray], + y_score: Union[List[float], np.ndarray], + plot_dir: Union[str, Path], + name: str = "nt", +): + assert len(y_true) == len(y_score) + os.makedirs(plot_dir, exist_ok=True) + precision, recall, _ = precision_recall_curve(1 - np.array(y_true), 1 - np.array(y_score)) + PrecisionRecallDisplay(precision=precision, recall=recall).plot() + plt.title(name) + plt.savefig(Path(plot_dir) / Path(name + ".png"), dpi=300) + plt.clf() + + +def save_custom_confidence_curve( + thresholds: Union[List[float], np.ndarray], + values: Union[List[float], np.ndarray], + plot_dir: Union[str, Path], + 
name: str = "my_awesome_curve", + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, +): + assert len(thresholds) == len(values) + os.makedirs(plot_dir, exist_ok=True) + plt.plot(thresholds, values) + plt.xlim([0, 1]) + plt.ylim([0, 1]) + plt.title(name) + if xlabel is not None: + plt.xlabel(xlabel) + if ylabel is not None: + plt.ylabel(ylabel) + plt.savefig(Path(plot_dir) / Path(name + ".png"), dpi=300) + plt.clf() diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index e4f024d231ad..44f2abec584f 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -275,7 +275,9 @@ def generate_overlap_vad_seq( if out_dir: overlap_out_dir = out_dir else: - overlap_out_dir = frame_pred_dir + "/overlap_smoothing_output" + "_" + smoothing_method + "_" + str(overlap) + overlap_out_dir = os.path.join( + frame_pred_dir, "overlap_smoothing_output" + "_" + smoothing_method + "_" + str(overlap) + ) if not os.path.exists(overlap_out_dir): os.mkdir(overlap_out_dir) @@ -732,7 +734,7 @@ def generate_vad_segment_table( if not out_dir: out_dir_name = "seg_output_" for key in postprocessing_params: - out_dir_name = out_dir_name + str(key) + str(postprocessing_params[key]) + "-" + out_dir_name = out_dir_name + "-" + str(key) + str(postprocessing_params[key]) out_dir = os.path.join(vad_pred_dir, out_dir_name) diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py index 0ab0cb784273..906154213ea1 100644 --- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py +++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py @@ -299,7 +299,7 @@ def create_spt_model( byte_fallback: If , fallback to a byte sequence of the character. split_digits: If true, digits are split into individual tokens. split_by_whitespace: Whether to respect white space while creating subwords. If False, will learn merges across whitespace. - split_by_unicode_script: Whether to include multiple Unicode scripts. Ex. is Arabic diacritics which are considered part of the letter (عِدَّةُ) + split_by_unicode_script: Whether to include multiple Unicode scripts. Ex. is Arabic diacritics which are considered part of the letter (عِدَّةُ) """ if not data_file or not os.path.exists(data_file): diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py index 756494f2f315..da3d03199c2e 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + import numpy as np import torch @@ -40,12 +42,13 @@ def __init__( label_key: str = "answer", separate_prompt_and_response_with_newline: bool = False, answer_only_loss: bool = True, - truncation_field: str = "answer", + truncation_field: str = "context", pad_to_max_length: bool = False, # (@adithyare) allows for much faster training especially in PEFT settings. index_mapping_dir: str = None, prompt_template: str = None, virtual_tokens: int = 0, tokens_to_generate: int = 0, + memmap_workers: Optional[int] = None, ): """ file_path: Path to a JSONL GPT supervised fine-tuning dataset. 
Data is formatted as multiple JSON lines with each line formatted as follows. {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} @@ -94,7 +97,11 @@ def __init__( assert self.truncation_field in ["answer", "context"] self.indexed_dataset = JSONLMemMapDataset( - dataset_paths=[file_path], tokenizer=None, header_lines=0, index_mapping_dir=index_mapping_dir + dataset_paths=[file_path], + tokenizer=None, + header_lines=0, + index_mapping_dir=index_mapping_dir, + workers=memmap_workers, ) # Will be None after this call if `max_num_samples` is None diff --git a/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py b/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py index cda551189d78..7385f19b414a 100644 --- a/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py +++ b/nemo/collections/nlp/data/spellchecking_asr_customization/utils.py @@ -764,12 +764,30 @@ def check_banned_replacements(src: str, dst: str) -> bool: # anticipated => anticipate if src.endswith("ed") and dst.endswith("e") and src[0:-2] == dst[0:-1]: return True + # blocks => blocked + if src.endswith("s") and dst.endswith("ed") and src[0:-1] == dst[0:-2]: + return True + # blocked => blocks + if src.endswith("ed") and dst.endswith("s") and src[0:-2] == dst[0:-1]: + return True + # lives => lived + if src.endswith("es") and dst.endswith("ed") and src[0:-2] == dst[0:-2]: + return True + # lived => lives + if src.endswith("ed") and dst.endswith("es") and src[0:-2] == dst[0:-2]: + return True # regarded => regard if src.endswith("ed") and src[0:-2] == dst: return True # regard => regarded if dst.endswith("ed") and dst[0:-2] == src: return True + # regardeding => regard + if src.endswith("ing") and src[0:-3] == dst: + return True + # regard => regarding + if dst.endswith("ing") and dst[0:-3] == src: + return True # longer => long if src.endswith("er") and src[0:-2] == dst: return True @@ -782,48 +800,102 @@ def check_banned_replacements(src: str, dst: str) -> bool: # discussing => discussed if src.endswith("ing") and dst.endswith("ed") and src[0:-3] == dst[0:-2]: return True + # live => living + if src.endswith("e") and dst.endswith("ing") and src[0:-1] == dst[0:-3]: + return True + # living => live + if src.endswith("ing") and dst.endswith("e") and src[0:-3] == dst[0:-1]: + return True # discussion => discussing if src.endswith("ion") and dst.endswith("ing") and src[0:-3] == dst[0:-3]: return True # discussing => discussion if src.endswith("ing") and dst.endswith("ion") and src[0:-3] == dst[0:-3]: return True + # alignment => aligning + if src.endswith("ment") and dst.endswith("ing") and src[0:-4] == dst[0:-3]: + return True + # aligning => alignment + if src.endswith("ing") and dst.endswith("ment") and src[0:-3] == dst[0:-4]: + return True # dispensers => dispensing if src.endswith("ers") and dst.endswith("ing") and src[0:-3] == dst[0:-3]: return True # dispensing => dispensers if src.endswith("ing") and dst.endswith("ers") and src[0:-3] == dst[0:-3]: return True + # integrate => integrity + if src.endswith("ate") and dst.endswith("ity") and src[0:-3] == dst[0:-3]: + return True + # integrity => integrate + if src.endswith("ity") and dst.endswith("ate") and src[0:-3] == dst[0:-3]: + return True # discussion => discussed if src.endswith("ion") and dst.endswith("ed") and src[0:-3] == dst[0:-2]: return True # discussed => discussion if 
src.endswith("ed") and dst.endswith("ion") and src[0:-2] == dst[0:-3]: return True + # anticipation => anticipate + if src.endswith("ion") and dst.endswith("e") and src[0:-3] == dst[0:-1]: + return True + # anticipate => anticipation + if src.endswith("e") and dst.endswith("ion") and src[0:-1] == dst[0:-3]: + return True # incremental => increment if src.endswith("ntal") and dst.endswith("nt") and src[0:-4] == dst[0:-2]: return True # increment => incremental if src.endswith("nt") and dst.endswith("ntal") and src[0:-2] == dst[0:-4]: return True + # national => nation + if src.endswith("nal") and dst.endswith("n") and src[0:-3] == dst[0:-1]: + return True + # nation => national + if src.endswith("n") and dst.endswith("nal") and src[0:-1] == dst[0:-3]: + return True + # significantly => significant + if src.endswith("ntly") and dst.endswith("nt") and src[0:-4] == dst[0:-2]: + return True + # significant => significantly + if src.endswith("nt") and dst.endswith("ntly") and src[0:-2] == dst[0:-4]: + return True # delivery => deliverer if src.endswith("ery") and dst.endswith("erer") and src[0:-3] == dst[0:-4]: return True # deliverer => delivery if src.endswith("erer") and dst.endswith("ery") and src[0:-4] == dst[0:-3]: return True + # deliver => deliverer + if src.endswith("er") and dst.endswith("erer") and src[0:-2] == dst[0:-4]: + return True + # deliverer => deliver + if src.endswith("erer") and dst.endswith("er") and src[0:-4] == dst[0:-2]: + return True # comparably => comparable if src.endswith("bly") and dst.endswith("ble") and src[0:-3] == dst[0:-3]: return True # comparable => comparably if src.endswith("ble") and dst.endswith("bly") and src[0:-3] == dst[0:-3]: return True + # comparably => comparability + if src.endswith("bly") and dst.endswith("bility") and src[0:-3] == dst[0:-6]: + return True + # comparability => comparably + if src.endswith("bility") and dst.endswith("bly") and src[0:-6] == dst[0:-3]: + return True # beautiful => beautifully if src.endswith("l") and dst.endswith("lly") and src[0:-1] == dst[0:-3]: return True # beautifully => beautiful if src.endswith("lly") and dst.endswith("l") and src[0:-3] == dst[0:-1]: return True + # active => actively + if src.endswith("e") and dst.endswith("ely") and src[0:-1] == dst[0:-3]: + return True + # actively => active + if src.endswith("ely") and dst.endswith("e") and src[0:-3] == dst[0:-1]: + return True # america => american if src.endswith("a") and dst.endswith("an") and src[0:-1] == dst[0:-2]: return True @@ -836,6 +908,18 @@ def check_banned_replacements(src: str, dst: str) -> bool: # investing => reinvesting if dst.startswith("re") and dst[2:] == src: return True + # unchanged => changed + if src.startswith("un") and src[2:] == dst: + return True + # changed => unchanged + if dst.startswith("un") and dst[2:] == src: + return True + # disrespected => respected + if src.startswith("dis") and src[3:] == dst: + return True + # respected => disrespected + if dst.startswith("dis") and dst[3:] == src: + return True # outperformance => performance if src.startswith("out") and src[3:] == dst: return True diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index 85186224494f..283cc0e499e8 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -221,6 +221,7 @@ def 
init_model(self, cfg: DictConfig, trainer: Trainer): "add_BOS": True, "all_probs": False, "compute_logprob": False, + "end_strings": self.cfg.inference.get('end_strings', ["<|endoftext|>"]), } elif self.cfg.get("report_validation_metric", False) and not hasattr(self.cfg, 'inference'): raise ValueError("Must provide inference parameters for reporting validation metric!") @@ -754,6 +755,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] "all_probs": inference_config["all_probs"], "compute_logprob": inference_config["compute_logprob"], "compute_attention_mask": inference_config.get("compute_attention_mask", True), + "end_strings": inference_config.get('end_strings', ["<|endoftext|>"]), } task_ids, processed_inputs = batch diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 6814c70d7a34..c0684b2d36d0 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -266,6 +266,9 @@ def _build_dataset(self, data_cfg, is_train=True): tokens_to_generate=data_cfg.get( 'tokens_to_generate', 0 ), # used at inference time to allocate tensor positions for tokens that will be generated by inf procedure. + memmap_workers=data_cfg.get( + 'memmap_workers', None + ), # used to set num. of workers to create the memmap index files ) datasets.append(dataset) @@ -385,6 +388,7 @@ def inference_step(self, dataloader_iter, batch_idx, mode, dataloader_idx=0): "add_BOS": False, "all_probs": False, "compute_logprob": False, + "end_strings": ["<|endoftext|>"], } result = megatron_gpt_generate( model=self, diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index 032a7449c27e..d739efa88485 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -16,7 +16,7 @@ import hashlib import json import os -from typing import Any, Optional +from typing import Any, Mapping, Optional from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer @@ -385,3 +385,13 @@ def load_from_checkpoint( finally: cls._set_model_restore_state(is_being_restored=False) return checkpoint + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + # starting with trasformers v4.31.0, buffer for position_ids is persistent=False + if ( + self.bert_model is not None + and "position_ids" not in self.bert_model.embeddings._modules + and "bert_model.embeddings.position_ids" in state_dict + ): + del state_dict["bert_model.embeddings.position_ids"] + super(NLPModel, self).load_state_dict(state_dict, strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_utils.py b/nemo/collections/nlp/modules/common/megatron/megatron_utils.py index d901a00a343b..68437921f930 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_utils.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_utils.py @@ -14,13 +14,14 @@ # limitations under the License. 
import os +import shutil from typing import Dict, List import torch import wget from torch.hub import _get_torch_home -from nemo.utils import logging +from nemo.utils import get_rank, logging __all__ = [ "get_megatron_lm_model", @@ -202,16 +203,14 @@ def _download(path: str, url: str): if url is None: return None - if not os.path.exists(path): - master_device = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 - if not os.path.exists(path): - if master_device: - os.makedirs(MEGATRON_CACHE, exist_ok=True) - logging.info(f"Downloading from {url}") - wget.download(url, path) - # wait until the master process downloads the file and writes it to the cache dir - if torch.distributed.is_initialized(): - torch.distributed.barrier() + if get_rank.is_global_rank_zero() and not os.path.exists(path): + os.makedirs(MEGATRON_CACHE, exist_ok=True) + logging.info(f"Downloading from {url} to {path}") + downloaded_path = wget.download(url) + shutil.move(downloaded_path, path) + # wait until the master process downloads the file and writes it to the cache dir + if torch.distributed.is_initialized(): + torch.distributed.barrier() return path diff --git a/nemo/collections/nlp/modules/common/text_generation_server.py b/nemo/collections/nlp/modules/common/text_generation_server.py index 5eb69eefcc3e..a9d3b2097af7 100644 --- a/nemo/collections/nlp/modules/common/text_generation_server.py +++ b/nemo/collections/nlp/modules/common/text_generation_server.py @@ -141,6 +141,14 @@ def put(self): if not (1.0 <= repetition_penalty): return "repetition_penalty must be a positive number no less than 1.0" + end_strings = ['<|endoftext|>'] + if 'end_strings' in request.get_json(): + end_strings = request.get_json()['end_strings'] + if not isinstance(end_strings, list): + return "expect end_strings to be a list of strings" + if not all([isinstance(s, str) for s in end_strings]): + return "expect end_strings to be a list of strings" + min_tokens_to_generate = 0 if "min_tokens_to_generate" in request.get_json(): min_tokens_to_generate = request.get_json()["min_tokens_to_generate"] @@ -157,14 +165,6 @@ def put(self): if neighbors < 0: return "num of neighbors must be an integer no less than 0" - end_strings = ['<|endoftext|>'] - if 'end_strings' in request.get_json(): - end_strings = request.get_json()['end_strings'] - if not isinstance(end_strings, list): - return "expect end_strings to be a list of strings" - if not all([isinstance(s, str) for s in end_strings]): - return "expect end_strings to be a list of strings" - with lock: # Need to get lock to keep multiple threads from hitting code MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate extra = {} @@ -190,8 +190,8 @@ def put(self): top_p, greedy, repetition_penalty, - min_tokens_to_generate, end_strings=end_strings, + min_tokens_to_generate=min_tokens_to_generate, **extra, ) for k in output: diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index d0383741efdc..6d7e9abd6a99 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -69,6 +69,7 @@ def get_default_sampling_params(): "add_BOS": True, "all_probs": False, "compute_logprob": False, + "end_strings": ["<|endoftext|>", ""], } return sampling_params @@ -104,6 +105,7 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para top_p=sampling_params['top_p'], 
greedy=sampling_params['use_greedy'], repetition_penalty=sampling_params['repetition_penalty'], + end_strings=sampling_params['end_strings'], min_tokens_to_generate=length_params['min_length'], compute_attention_mask=sampling_params.get("compute_attention_mask", True), **strategy_args, @@ -125,6 +127,7 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para top_p=sampling_params['top_p'], greedy=sampling_params['use_greedy'], repetition_penalty=sampling_params['repetition_penalty'], + end_strings=sampling_params['end_strings'], min_tokens_to_generate=length_params['min_length'], **strategy_args, ) @@ -380,8 +383,8 @@ def synced_generate( compute_attention_mask=True, compute_logprob=False, repetition_penalty=1.2, - min_tokens_to_generate=0, end_strings=[], + min_tokens_to_generate=0, ): context_length = context_length_tensor.min().item() tokenizer = model.tokenizer @@ -393,6 +396,7 @@ def synced_generate( context_length_tensor, tokens_to_generate, all_probs, + compute_attention_mask=compute_attention_mask, temperature=temperature, ) else: @@ -475,8 +479,8 @@ def generate( compute_attention_mask=True, compute_logprob=False, repetition_penalty=1.0, - min_tokens_to_generate=0, end_strings=['<|endoftext|>'], + min_tokens_to_generate=0, **strategy_args, ) -> OutputType: """ @@ -560,8 +564,8 @@ def generate( top_p=top_p, greedy=greedy, repetition_penalty=repetition_penalty, - min_tokens_to_generate=min_tokens_to_generate, end_strings=end_strings, + min_tokens_to_generate=min_tokens_to_generate, ) special_tokens = set() if hasattr(tokenizer, 'pad_token') and tokenizer.pad_token is not None: @@ -822,6 +826,7 @@ def tab_sample_sequence_batch( context_lengths, tokens_to_generate, all_probs=True, + compute_attention_mask=True, type_ids=None, temperature=None, ): @@ -845,7 +850,7 @@ def tab_sample_sequence_batch( # initialize the batch with torch.no_grad(): context_length = context_lengths.min().item() - inference_strategy.init_batch(context_tokens, context_length) + inference_strategy.init_batch(context_tokens, context_length, compute_attention_mask) context = context_tokens[:, :context_length] # the context may start in the middle of the row, # calculate the offset according to the position of '\n' or '<|endoftext|>' @@ -879,7 +884,7 @@ def tab_sample_sequence_batch( while context_length < maxlen: batch, tensor_shape = inference_strategy.prepare_batch_at_step( - tokens, maxlen, micro_batch_size, counter, context_length + tokens, maxlen, micro_batch_size, counter, context_length, compute_attention_mask ) output = inference_strategy.forward_step(batch, tensor_shape) diff --git a/nemo/collections/nlp/modules/common/transformer/text_generation.py b/nemo/collections/nlp/modules/common/transformer/text_generation.py index a261e925691f..28db41b8a27a 100644 --- a/nemo/collections/nlp/modules/common/transformer/text_generation.py +++ b/nemo/collections/nlp/modules/common/transformer/text_generation.py @@ -37,6 +37,7 @@ class SamplingParam(TypedDict): add_BOS: bool # add the bos token at the begining of the prompt all_probs: bool # whether return the log prob for all the tokens in vocab compute_logprob: bool # a flag used to compute logprob of all the input text, a very special case of running inference, default False + end_strings: List[str] # generation will stop when one of these tokens is generated class OutputType(TypedDict): @@ -88,6 +89,7 @@ def generate( add_BOS: bool, Whether add the bos token at the begining of the prompt all_probs: bool # whether return the log prob for 
all the tokens in vocab compute_logprob: bool # a flag used to compute logprob of all the input text, a very special case of running inference, default False + end_strings: List[str] # generation will stop when one of these tokens is generated Default None, If it is None, use_greedy will be "True". Returns: OutputType: It generates the output in a dictionary type. It has the following keys: diff --git a/nemo/collections/tts/models/base.py b/nemo/collections/tts/models/base.py index 8ef147b9b145..fe19ae75a3b3 100644 --- a/nemo/collections/tts/models/base.py +++ b/nemo/collections/tts/models/base.py @@ -68,6 +68,18 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]': list_of_models.extend(subclass_models) return list_of_models + def set_export_config(self, args): + for k in ['enable_volume', 'enable_ragged_batches']: + if k in args: + self.export_config[k] = bool(args[k]) + args.pop(k) + if 'num_speakers' in args: + self.export_config['num_speakers'] = int(args['num_speakers']) + args.pop('num_speakers') + if 'emb_range' in args: + raise Exception('embedding range is not user-settable') + super().set_export_config(args) + class Vocoder(ModelPT, ABC): """ diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index dc598a9a76d1..8f0e06ea304d 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -772,6 +772,20 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]': ) list_of_models.append(model) + # en, multi speaker, LibriTTS, 16000 Hz + # stft 25ms 10ms matching ASR params + # for use during English ASR training/adaptation + model = PretrainedModelInfo( + pretrained_model_name="tts_en_fastpitch_for_asr_finetuning", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning/versions/1.20.0/files/tts_en_fastpitch_for_asr_finetuning.nemo", + description="This model is trained on LibriSpeech, train-960 subset." + " STFT parameters follow those commonly used in ASR: 25 ms window, 10 ms hop." + " This model is supposed to be used with its companion SpectrogramEnhancer for " + " ASR fine-tuning. 
Usage for regular TTS tasks is not advised.", + class_=cls, + ) + list_of_models.append(model) + return list_of_models # Methods for model exportability diff --git a/nemo/collections/tts/models/spectrogram_enhancer.py b/nemo/collections/tts/models/spectrogram_enhancer.py index bcc7e69a10bf..ca2fe6122230 100644 --- a/nemo/collections/tts/models/spectrogram_enhancer.py +++ b/nemo/collections/tts/models/spectrogram_enhancer.py @@ -56,7 +56,7 @@ HingeLoss, ) from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor, to_device_recursive -from nemo.core import Exportable, ModelPT, typecheck +from nemo.core import Exportable, ModelPT, PretrainedModelInfo, typecheck from nemo.core.neural_types import LengthsType, MelSpectrogramType, NeuralType from nemo.core.neural_types.elements import BoolType from nemo.utils import logging @@ -277,7 +277,23 @@ def setup_validation_data(self, val_data_config): @classmethod def list_available_models(cls): - return [] + list_of_models = [] + + # en, multi speaker, LibriTTS, 16000 Hz + # stft 25ms 10ms matching ASR params + # for use during English ASR training/adaptation + model = PretrainedModelInfo( + pretrained_model_name="tts_en_spectrogram_enhancer_for_asr_finetuning", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning/versions/1.20.0/files/tts_en_spectrogram_enhancer_for_asr_finetuning.nemo", + description="This model is trained to add details to synthetic spectrograms." + " It was trained on pairs of real-synthesized spectrograms generated by FastPitch." + " STFT parameters follow ASR with 25 ms window and 10 ms hop." + " It is supposed to be used in conjunction with that model for ASR training/adaptation.", + class_=cls, + ) + list_of_models.append(model) + + return list_of_models def log_illustration(self, target_spectrograms, input_spectrograms, enhanced_spectrograms, lengths): if self.global_rank != 0: diff --git a/nemo/core/neural_types/elements.py b/nemo/core/neural_types/elements.py index 10638a9c461a..f2de48da26d0 100644 --- a/nemo/core/neural_types/elements.py +++ b/nemo/core/neural_types/elements.py @@ -21,6 +21,7 @@ __all__ = [ 'ElementType', 'VoidType', + 'BoolType', 'ChannelType', 'AcousticEncodedRepresentation', 'AudioSignal', diff --git a/nemo/core/utils/numba_utils.py b/nemo/core/utils/numba_utils.py index 04010a2f7db4..9117b2ea1010 100644 --- a/nemo/core/utils/numba_utils.py +++ b/nemo/core/utils/numba_utils.py @@ -29,9 +29,6 @@ __NUMBA_MINIMUM_VERSION__ = os.environ.get("NEMO_NUMBA_MINVER", __NUMBA_DEFAULT_MINIMUM_VERSION__) __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__ = "0.57.0" -NUMBA_FP16_SUPPORTED = model_utils.check_lib_version( - 'numba', __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__, operator=operator.ge -)[0] NUMBA_INSTALLATION_MESSAGE = ( @@ -171,12 +168,16 @@ def is_numba_cuda_fp16_supported(return_reason: bool = False) -> Union[bool, Tup use_nvidia_binding = False reason += "Env variable `NUMBA_CUDA_USE_NVIDIA_BINDING` is not available or has not set to `1`." - if NUMBA_FP16_SUPPORTED: + numba_fp16_version_correct = model_utils.check_lib_version( + 'numba', __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__, operator=operator.ge + )[0] + + if numba_fp16_version_correct: reason += f"Numba CUDA FP16 is supported in installed numba version." else: reason += f"Numba CUDA FP16 is not supported in installed numba version. 
- result = use_nvidia_binding and NUMBA_FP16_SUPPORTED + result = use_nvidia_binding and numba_fp16_version_correct if return_reason: return result, reason diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 211ffdcdf11e..42a0b108944d 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -13,6 +13,7 @@ # limitations under the License. import copy +import importlib import os from dataclasses import dataclass, is_dataclass from enum import Enum @@ -554,7 +555,7 @@ def check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Op if '.' in lib_name: mod = import_class_by_path(lib_name) else: - mod = __import__(lib_name) + mod = importlib.import_module(lib_name) if hasattr(mod, '__version__'): lib_ver = version.Version(mod.__version__) diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt index 100216aebc54..9c41c355e8cd 100644 --- a/requirements/requirements_lightning.txt +++ b/requirements/requirements_lightning.txt @@ -1,7 +1,6 @@ hydra-core>=1.2.0,<1.3 omegaconf>=2.2,<2.3 pytorch-lightning>=1.9.0,<=1.9.4 -pyyaml<6 # Pinned until omegaconf works with pyyaml>=6 torchmetrics>=0.11.0 transformers>=4.0.1 wandb diff --git a/scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh b/scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh index 558a84698f49..3ba337a6afd3 100755 --- a/scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh +++ b/scripts/asr_language_modeling/ngram_lm/install_beamsearch_decoders.sh @@ -26,14 +26,15 @@ KENLM_MAX_ORDER=10 # Maximum order of KenLM model, also specified in the setup_o cd $NEMO_PATH if [ $(id -u) -eq 0 ]; then - alias aptupdate='apt-get update' - alias b2install='./b2' -else - alias aptupdate='sudo apt-get update' - alias b2install='sudo ./b2' + alias aptupdate='apt-get update' + alias b2install='./b2' + else + alias aptupdate='sudo apt-get update' + alias b2install='sudo ./b2' fi -aptupdate && apt-get upgrade -y && apt-get install -y liblzma-dev && rm -rf /var/lib/apt/lists/* # liblzma needed for flashlight decoder' +aptupdate && apt-get upgrade -y && apt-get install -y liblzma-dev && rm -rf /var/lib/apt/lists/* # liblzma needed for flashlight decoder + git clone https://github.com/NVIDIA/OpenSeq2Seq cd OpenSeq2Seq diff --git a/scripts/asr_language_modeling/ngram_lm/kenlm_utils.py b/scripts/asr_language_modeling/ngram_lm/kenlm_utils.py index 9e255ddc50ca..d9b48afab292 100644 --- a/scripts/asr_language_modeling/ngram_lm/kenlm_utils.py +++ b/scripts/asr_language_modeling/ngram_lm/kenlm_utils.py @@ -79,11 +79,8 @@ def setup_tokenizer(nemo_model_file): ) model = nemo_asr.models.ASRModel.from_pretrained(nemo_model_file, map_location=torch.device('cpu')) - if type(model.tokenizer).__name__ == 'AggregateTokenizer': - is_aggregate_tokenizer = True - else: - is_aggregate_tokenizer = False - + is_aggregate_tokenizer = False + tokenizer_nemo = None encoding_level = SUPPORTED_MODELS.get(type(model).__name__, None) if not encoding_level: logging.warning( @@ -91,7 +88,12 @@ def setup_tokenizer(nemo_model_file): ) encoding_level = 'char' - tokenizer_nemo = model.tokenizer + if encoding_level == 'subword': + if type(model.tokenizer).__name__ == 'AggregateTokenizer': + is_aggregate_tokenizer = True + + tokenizer_nemo = model.tokenizer + del model return tokenizer_nemo, encoding_level, is_aggregate_tokenizer @@ -117,10 +119,10 @@ def iter_files(source_path, dest_path, tokenizer, encoding_level, is_aggregate_t if isinstance(dest_path, 
str): with open(dest_path, 'w', encoding='utf-8') as f: for line in dataset: - f.write(line + "\n") + f.write(line[0] + "\n") else: # write to stdin of KenLM for line in dataset: - dest_path.write((line + '\n').encode()) + dest_path.write((line[0] + '\n').encode()) def read_train_file( diff --git a/scripts/asr_language_modeling/ngram_lm/ngram_merge.py b/scripts/asr_language_modeling/ngram_lm/ngram_merge.py index abffc6372518..b6606286ae5b 100644 --- a/scripts/asr_language_modeling/ngram_lm/ngram_merge.py +++ b/scripts/asr_language_modeling/ngram_lm/ngram_merge.py @@ -51,6 +51,7 @@ import torch import nemo.collections.asr as nemo_asr +from nemo.collections.asr.modules.rnnt import RNNTDecoder from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET from nemo.utils import logging @@ -207,9 +208,7 @@ def make_arpa(self, ngram_mod: str, ngram_arpa: str, force: bool): ] return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,) - def test_perplexity( - self, mod_c: str, symbols: str, test_txt: str, nemo_model_file: str, tmp_path: str, force: bool - ) -> str: + def test_perplexity(self, mod_c: str, symbols: str, test_txt: str, nemo_model_file: str, tmp_path: str) -> str: """ Tests the perplexity of a given ngram model on a test file. @@ -229,12 +228,12 @@ def test_perplexity( 'Perplexity: 123.45' """ - test_far = farcompile(symbols, test_txt, tmp_path, nemo_model_file, force) + test_far = farcompile(symbols, test_txt, tmp_path, nemo_model_file) res_p = self.perplexity(mod_c, test_far) return res_p -def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str, force: bool,) -> str: +def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str) -> str: """ Compiles a text file into a FAR file using the given symbol table or tokenizer. @@ -253,43 +252,35 @@ def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str """ test_far = os.path.join(tmp_path, os.path.split(text_file)[1] + ".far") - if os.path.isfile(test_far) and not force: - logging.info("File " + test_far + " exists. 
Skipping.") - return None - else: - sh_args = [ - "farcompilestrings", - "--generate_keys=10", - "--fst_type=compact", - "--symbols=" + symbols, - "--keep_symbols", - ">", - test_far, - ] - - tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file) - - ps = subprocess.Popen( - " ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr, - ) - - kenlm_utils.iter_files( - source_path=[text_file], - dest_path=ps.stdin, - tokenizer=tokenizer, - encoding_level=encoding_level, - is_aggregate_tokenizer=is_aggregate_tokenizer, - verbose=1, - ) - stdout, stderr = ps.communicate() + sh_args = [ + "farcompilestrings", + "--generate_keys=10", + "--fst_type=compact", + "--symbols=" + symbols, + "--keep_symbols", + ">", + test_far, + ] + + tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file) + + ps = subprocess.Popen(" ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr,) + + kenlm_utils.iter_files( + source_path=[text_file], + dest_path=ps.stdin, + tokenizer=tokenizer, + encoding_level=encoding_level, + is_aggregate_tokenizer=is_aggregate_tokenizer, + verbose=1, + ) + stdout, stderr = ps.communicate() - exit_code = ps.returncode + exit_code = ps.returncode - command = " ".join(sh_args) - assert ( - exit_code == 0 - ), f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}" - return test_far + command = " ".join(sh_args) + assert exit_code == 0, f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}" + return test_far def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool): @@ -310,7 +301,7 @@ def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool): logging.info("File " + ngram_kenlm + " exists. Skipping.") return None else: - sh_args = [kenlm_bin_path, "trie", "-i", ngram_arpa, ngram_kenlm] + sh_args = [os.path.join(kenlm_bin_path, "build_binary"), "trie", "-i", ngram_arpa, ngram_kenlm] return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,) @@ -336,12 +327,15 @@ def make_symbol_list(nemo_model_file, symbols, force): else: if nemo_model_file.endswith('.nemo'): asr_model = nemo_asr.models.ASRModel.restore_from(nemo_model_file, map_location=torch.device('cpu')) - vocab_size = len(asr_model.decoder.vocabulary) else: logging.warning( "nemo_model_file does not end with .nemo, therefore trying to load a pretrained model with this name." 
) asr_model = nemo_asr.models.ASRModel.from_pretrained(nemo_model_file, map_location=torch.device('cpu')) + + if isinstance(asr_model.decoder, RNNTDecoder): + vocab_size = asr_model.decoder.blank_idx + else: vocab_size = len(asr_model.decoder.vocabulary) vocab = [chr(idx + DEFAULT_TOKEN_OFFSET) for idx in range(vocab_size)] @@ -389,8 +383,9 @@ def main( if not symbols: symbols = os.path.join(out_path, os.path.split(nemo_model_file)[1] + ".syms") make_symbol_list(nemo_model_file, symbols, force) - test_p = nm.test_perplexity(mod_c, symbols, test_file, nemo_model_file, out_path, force) - logging.info("Perplexity summary:" + test_p) + for test_f in test_file.split(","): + test_p = nm.test_perplexity(mod_c, symbols, test_f, nemo_model_file, out_path) + logging.info("Perplexity summary " + test_f + " : " + test_p) logging.info("Making ARPA and Kenlm model " + arpa_c) out = nm.make_arpa(mod_c, arpa_c, force) diff --git a/scripts/confidence_ensembles/build_ensemble.py b/scripts/confidence_ensembles/build_ensemble.py index b5685c63aa25..bc32a4f99840 100644 --- a/scripts/confidence_ensembles/build_ensemble.py +++ b/scripts/confidence_ensembles/build_ensemble.py @@ -59,7 +59,7 @@ python build_ensemble.py - tune_confidence_config.confidence_type='[entropy_renui_exp,entropy_tsallis_exp]' # only tune over this set + tune_confidence_config.confidence_type='[entropy_renyi_exp,entropy_tsallis_exp]' # only tune over this set tune_confidence_config.alpha='[0.1,0.5,1.0]' # only tune over this set You can check the dataclasses in this file for the full list of supported @@ -97,7 +97,7 @@ ) from nemo.collections.asr.parts.utils.asr_confidence_utils import ( ConfidenceConfig, - ConfidenceMethodConfig, + ConfidenceMeasureConfig, get_confidence_aggregation_bank, get_confidence_measure_bank, ) @@ -143,8 +143,8 @@ class TuneConfidenceConfig: # not including max prob, as there is always an entropy-based metric # that's better but otherwise including everything confidence_type: Tuple[str] = ( - "entropy_renui_exp", - "entropy_renui_lin", + "entropy_renyi_exp", + "entropy_renyi_lin", "entropy_tsallis_exp", "entropy_tsallis_lin", "entropy_gibbs_lin", @@ -214,14 +214,9 @@ class BuildEnsembleConfig: preserve_frame_confidence=True, exclude_blank=True, aggregation="mean", - method_cfg=ConfidenceMethodConfig( - name="entropy", - entropy_type="renui", - temperature=0.25, # this is not really temperature, but alpha, see https://arxiv.org/abs/2212.08703 - entropy_norm="lin", - ), + measure_cfg=ConfidenceMeasureConfig(name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",), ) - temperature: float = 1.0 # this is a real temperature that will be applied to logits + temperature: float = 1.0 # this is optional, but can be used to change any aspect of the transcription # config, such as batch size or amp usage. 
Note that model, data and confidence diff --git a/scripts/confidence_ensembles/ensemble_config.yaml b/scripts/confidence_ensembles/ensemble_config.yaml index 954876a0c3cc..590318ee3b28 100644 --- a/scripts/confidence_ensembles/ensemble_config.yaml +++ b/scripts/confidence_ensembles/ensemble_config.yaml @@ -16,8 +16,8 @@ temperature: 1.0 confidence: exclude_blank: True aggregation: mean - method_cfg: + measure_cfg: name: entropy - entropy_type: renui - temperature: 0.25 # this is not really temperature, but alpha, see https://arxiv.org/abs/2212.08703 + entropy_type: renyi + alpha: 0.25 entropy_norm: lin diff --git a/scripts/export.py b/scripts/export.py index 4b21bc4ffd73..8fa44bb305f9 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -63,7 +63,7 @@ def get_args(argv): parser.add_argument("--device", default="cuda", help="Device to export for") parser.add_argument("--check-tolerance", type=float, default=0.01, help="tolerance for verification") parser.add_argument( - "--config", + "--export-config", metavar="KEY=VALUE", nargs='+', help="Set a number of key-value pairs to model.export_config dictionary " @@ -142,8 +142,14 @@ def nemo_export(argv): if args.cache_support: model.set_export_config({"cache_support": "True"}) - if args.config: - kv = dict(map(lambda s: s.split('='), args.config)) + if args.export_config: + kv = {} + for key_value in args.export_config: + lst = key_value.split("=") + if len(lst) != 2: + raise Exception("Use correct format for --export_config: k=v") + k, v = lst + kv[k] = v model.set_export_config(kv) autocast = nullcontext diff --git a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py index a43e80b2bc3f..8922fe09176d 100644 --- a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py +++ b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py @@ -12,32 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import copy import json import os from dataclasses import dataclass, is_dataclass from pathlib import Path from typing import Optional -import matplotlib.pyplot as plt -import numpy as np import pytorch_lightning as pl -import texterrors import torch -from omegaconf import MISSING, OmegaConf, open_dict -from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay, precision_recall_curve, roc_curve +from omegaconf import MISSING, OmegaConf from sklearn.model_selection import ParameterGrid from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig from nemo.collections.asr.metrics.wer import CTCDecodingConfig -from nemo.collections.asr.models import ASRModel +from nemo.collections.asr.models import ASRModel, EncDecRNNTModel +from nemo.collections.asr.parts.utils.asr_confidence_benchmarking_utils import ( + apply_confidence_parameters, + run_confidence_benchmark, +) from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig -from nemo.collections.asr.parts.utils.confidence_metrics import auc_nt, auc_pr, auc_roc, auc_yc, ece, nce from nemo.core.config import hydra_runner from nemo.utils import logging - """ Get confidence metrics and curve plots for a given model, dataset, and confidence parameters. 
@@ -74,125 +70,10 @@ amp=True \ target_level="word" \ confidence_cfg.exclude_blank=False \ - 'grid_params="{\"aggregation\": [\"min\", \"prod\"], \"temperature\": [0.33, 0.5]}"' + 'grid_params="{\"aggregation\": [\"min\", \"prod\"], \"alpha\": [0.33, 0.5]}"' """ -def get_correct_marks(r, h): - """Get correct marks by aligning the reference text with a hypothesis. - - This method considers only insertions and substitutions as incorrect marks. - """ - return [ - a == b - for a, b in zip(*(texterrors.align_texts([str(rr) for rr in r], [str(hh) for hh in h], False)[:-1])) - if b != "" - ] - - -def get_token_targets_with_confidence(hyp): - return [[y, c] for y, c in zip(hyp.y_sequence, hyp.token_confidence)] - - -def get_word_targets_with_confidence(hyp): - return [[y, c] for y, c in zip(hyp.words, hyp.word_confidence)] - - -def run_benchmark( - model, batch_size, num_workers, is_rnnt, target_level, filepaths, reference_texts, plot_dir, autocast -): - """Run benchmark and plot histograms and curves. - - Returns: - Dictionary with benchmark results of the following scheme: - `level: (auc_roc, auc_pr, auc_nt, nce, ece, auc_yc, max_yc, std_yc)` with `level` being 'token' or 'word'. - """ - # transcribe audio - with autocast(): - with torch.no_grad(): - transcriptions = model.transcribe( - paths2audio_files=filepaths, batch_size=batch_size, return_hypotheses=True, num_workers=num_workers - ) - if is_rnnt: - transcriptions = transcriptions[0] - - levels = [] - if target_level != "word": - levels.append("token") - if target_level != "token": - levels.append("word") - results = {} - for level in levels: - if level == "token": - targets_with_confidence = [get_token_targets_with_confidence(tran) for tran in transcriptions] - correct_marks = [ - get_correct_marks(model.tokenizer.text_to_ids(r), model.tokenizer.text_to_ids(h.text)) - for r, h in zip(reference_texts, transcriptions) - ] - else: # "word" - targets_with_confidence = [get_word_targets_with_confidence(tran) for tran in transcriptions] - correct_marks = [get_correct_marks(r.split(), h.words) for r, h in zip(reference_texts, transcriptions)] - - y_true, y_score = np.array( - [[f, p[1]] for cm, twc in zip(correct_marks, targets_with_confidence) for f, p in zip(cm, twc)] - ).T - mask_correct = y_true == 1 - y_score_correct = y_score[mask_correct] - y_score_incorrect = y_score[~mask_correct] - result_yc = auc_yc(y_true, y_score, return_std_maximum=True, return_curve=True) - results[level] = [ - auc_roc(y_true, y_score), - auc_pr(y_true, y_score), - auc_nt(y_true, y_score), - nce(y_true, y_score), - ece(y_true, y_score), - ] + list(result_yc[:-1]) - - os.makedirs(plot_dir, exist_ok=True) - plt.hist(np.array(y_score_correct), 50, range=(0, 1)) - plt.savefig(plot_dir / Path(level + "_" + "hist_correct.png"), dpi=300) - plt.clf() - plt.hist(np.array(y_score_incorrect), 50, range=(0, 1)) - plt.savefig(plot_dir / Path(level + "_" + "hist_incorrect.png"), dpi=300) - plt.clf() - fpr, tpr, _ = roc_curve(1 - y_true, 1 - y_score) - RocCurveDisplay(fpr=fpr, tpr=tpr).plot() - plt.savefig(plot_dir / Path(level + "_" + "roc.png"), dpi=300) - plt.clf() - precision, recall, _ = precision_recall_curve(y_true, y_score) - PrecisionRecallDisplay(precision=precision, recall=recall).plot() - plt.savefig(plot_dir / Path(level + "_" + "pr.png"), dpi=300) - plt.clf() - precision, recall, _ = precision_recall_curve(1 - y_true, 1 - y_score) - PrecisionRecallDisplay(precision=precision, recall=recall).plot() - plt.savefig(plot_dir / Path(level + "_" + "nt.png"), dpi=300) - 
plt.clf() - plt.plot(*result_yc[-1]) - plt.ylim([0, 1]) - plt.savefig(plot_dir / Path(level + "_" + "yc.png"), dpi=300) - plt.clf() - - return results - - -def apply_parameters(decoding_cfg, hp): - """Apply parameters from a parameter grid to a decoding config. - - Returns: - Updated decoding config. - """ - new_decoding_cfg = copy.deepcopy(decoding_cfg) - confidence_cfg_fields = ("aggregation", "exclude_blank") - confidence_method_cfg_fields = ("name", "temperature", "entropy_type", "entropy_norm") - with open_dict(new_decoding_cfg): - for p, v in hp.items(): - if p in confidence_cfg_fields: - new_decoding_cfg.confidence_cfg[p] = v - elif p in confidence_method_cfg_fields: - new_decoding_cfg.confidence_cfg.method_cfg[p] = v - return new_decoding_cfg - - def get_experiment_params(cfg): """Get experiment parameters from a confidence config and generate the experiment name. @@ -202,23 +83,23 @@ def get_experiment_params(cfg): """ blank = "no_blank" if cfg.exclude_blank else "blank" aggregation = cfg.aggregation - method_name = cfg.method_cfg.name - temperature = cfg.method_cfg.temperature + method_name = cfg.measure_cfg.name + alpha = cfg.measure_cfg.alpha if method_name == "entropy": - entropy_type = cfg.method_cfg.entropy_type - entropy_norm = cfg.method_cfg.entropy_norm + entropy_type = cfg.measure_cfg.entropy_type + entropy_norm = cfg.measure_cfg.entropy_norm experiment_param_list = [ aggregation, str(cfg.exclude_blank), method_name, entropy_type, entropy_norm, - str(temperature), + str(alpha), ] - experiment_str = "-".join([aggregation, blank, method_name, entropy_type, entropy_norm, str(temperature)]) + experiment_str = "-".join([aggregation, blank, method_name, entropy_type, entropy_norm, str(alpha)]) else: - experiment_param_list = [aggregation, str(cfg.exclude_blank), method_name, "-", "-", str(temperature)] - experiment_str = "-".join([aggregation, blank, method_name, str(temperature)]) + experiment_param_list = [aggregation, str(cfg.exclude_blank), method_name, "-", "-", str(alpha)] + experiment_str = "-".join([aggregation, blank, method_name, str(alpha)]) return experiment_param_list, experiment_str @@ -294,7 +175,7 @@ def main(cfg: ConfidenceBenchmarkingConfig): asr_model = asr_model.eval() # Check if ctc or rnnt model - is_rnnt = hasattr(asr_model, 'joint') + is_rnnt = isinstance(asr_model, EncDecRNNTModel) # Check that the model has the `change_decoding_strategy` method if not hasattr(asr_model, 'change_decoding_strategy'): @@ -317,14 +198,10 @@ def main(cfg: ConfidenceBenchmarkingConfig): reference_texts.append(item['text']) # setup AMP (optional) + autocast = None if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'): logging.info("AMP enabled!\n") autocast = torch.cuda.amp.autocast - else: - - @contextlib.contextmanager - def autocast(): - yield # do grid-based benchmarking if grid_params is provided, otherwise a regular one work_dir = Path(cfg.output_dir) @@ -338,7 +215,7 @@ def autocast(): "method_name", "entropy_type", "entropy_norm", - "temperature", + "alpha", "target_level", "auc_roc", "auc_pr", @@ -346,8 +223,8 @@ def autocast(): "nce", "ece", "auc_yc", - "max_yc", "std_yc", + "max_yc", ] ) + "\n" @@ -374,17 +251,16 @@ def autocast(): f.flush() for i, hp in enumerate(hp_grid): logging.info(f"Run # {i + 1}, grid: `{hp}`") - asr_model.change_decoding_strategy(apply_parameters(asr_model.cfg.decoding, hp)) + asr_model.change_decoding_strategy(apply_confidence_parameters(asr_model.cfg.decoding, hp)) param_list, 
experiment_name = get_experiment_params(asr_model.cfg.decoding.confidence_cfg) plot_dir = work_dir / Path(experiment_name) - results = run_benchmark( + results = run_confidence_benchmark( asr_model, - cfg.batch_size, - cfg.num_workers, - is_rnnt, cfg.target_level, filepaths, reference_texts, + cfg.batch_size, + cfg.num_workers, plot_dir, autocast, ) @@ -406,11 +282,10 @@ def autocast(): with open(report_file, "tw", encoding="utf-8") as f: f.write(report_legend) f.flush() - results = run_benchmark( + results = run_confidence_benchmark( asr_model, cfg.batch_size, cfg.num_workers, - is_rnnt, cfg.target_level, filepaths, reference_texts, diff --git a/tests/collections/asr/confidence/test_asr_confidence.py b/tests/collections/asr/confidence/test_asr_confidence.py new file mode 100644 index 000000000000..11b127424908 --- /dev/null +++ b/tests/collections/asr/confidence/test_asr_confidence.py @@ -0,0 +1,144 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import math +import tempfile +from pathlib import Path + +import numpy as np +import pytest +from omegaconf import OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig +from nemo.collections.asr.metrics.wer import CTCDecodingConfig +from nemo.collections.asr.models import ASRModel, EncDecCTCModelBPE, EncDecRNNTBPEModel +from nemo.collections.asr.parts.submodules.ctc_greedy_decoding import GreedyCTCInferConfig +from nemo.collections.asr.parts.submodules.rnnt_greedy_decoding import GreedyRNNTInferConfig +from nemo.collections.asr.parts.utils.asr_confidence_benchmarking_utils import run_confidence_benchmark +from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig + +# both models recognize the test data without errors, thus every metric except ece return default values +ECE_VALUES = {("token", "ctc"): 0.87, ("token", "rnnt"): 0.82, ("word", "ctc"): 0.91, ("word", "rnnt"): 0.88} + +TOL_DEGREE = 2 +TOL = 1 / math.pow(10, TOL_DEGREE) + + +@pytest.fixture(scope="module") +def conformer_ctc_bpe_model(): + model = EncDecCTCModelBPE.from_pretrained(model_name="stt_en_conformer_ctc_small") + model.set_trainer(Trainer(devices=1, accelerator="cpu")) + model = model.eval() + return model + + +@pytest.fixture(scope="module") +def conformer_rnnt_bpe_model(): + model = EncDecRNNTBPEModel.from_pretrained(model_name="stt_en_conformer_transducer_small") + model.set_trainer(Trainer(devices=1, accelerator="cpu")) + model = model.eval() + return model + + +@pytest.mark.with_downloads +@pytest.fixture(scope="module") +# @pytest.fixture +def audio_and_texts(test_data_dir): + # get filenames and reference texts from manifest + filepaths = [] + reference_texts = [] + manifest = Path(test_data_dir) / Path("asr/an4_val.json") + with open(manifest, 'r') as f: + for line in f: + item = json.loads(line) + # alaptev: maybe fix those paths in the manifest? 
+ audio_file = Path(item['audio_filepath'].replace("/data/", "/.data/")) + filepaths.append(str(audio_file.absolute())) + reference_texts.append(item['text']) + return filepaths, reference_texts + + +class TestASRConfidenceBenchmark: + @pytest.mark.integration + @pytest.mark.with_downloads + @pytest.mark.parametrize('model_name', ("ctc", "rnnt")) + @pytest.mark.parametrize('target_level', ("token", "word")) + def test_run_confidence_benchmark( + self, model_name, target_level, audio_and_texts, conformer_ctc_bpe_model, conformer_rnnt_bpe_model + ): + model = conformer_ctc_bpe_model if model_name == "ctc" else conformer_rnnt_bpe_model + assert isinstance(model, ASRModel) + filepaths, reference_texts = audio_and_texts + confidence_cfg = ( + ConfidenceConfig(preserve_token_confidence=True) + if target_level == "token" + else ConfidenceConfig(preserve_word_confidence=True) + ) + model.change_decoding_strategy( + RNNTDecodingConfig(fused_batch_size=-1, strategy="greedy_batch", confidence_cfg=confidence_cfg) + if model_name == "rnnt" + else CTCDecodingConfig(confidence_cfg=confidence_cfg) + ) + with tempfile.TemporaryDirectory() as tmpdir: + assert np.allclose( + np.array( + run_confidence_benchmark(model, target_level, filepaths, reference_texts, plot_dir=tmpdir)[ + target_level + ] + ), + np.array([0.5, 1.0, 0.0, -math.inf, ECE_VALUES[(target_level, model_name)], 0.0, 0.0, 0.0]), + atol=TOL, + ) + + @pytest.mark.integration + @pytest.mark.with_downloads + @pytest.mark.parametrize('model_name', ("ctc", "rnnt")) + @pytest.mark.parametrize('arg', ("method_cfg", "temperature", "all")) + def test_deprecated_config_args(self, model_name, arg, conformer_ctc_bpe_model, conformer_rnnt_bpe_model): + assert ConfidenceConfig().measure_cfg.alpha == 0.33, "default `alpha` is supposed to be 0.33" + model = conformer_ctc_bpe_model if model_name == "ctc" else conformer_rnnt_bpe_model + assert isinstance(model, ASRModel) + if arg == "all": + conf = OmegaConf.create({"temperature": 0.5}) + test_args_main = {"method_cfg": conf} + test_args_greedy = {"confidence_method_cfg": conf} + elif arg == "method_cfg": + conf = OmegaConf.create({"alpha": 0.5}) + test_args_main = {"method_cfg": conf} + test_args_greedy = {"confidence_method_cfg": conf} + elif arg == "temperature": + conf = OmegaConf.create({"temperature": 0.5}) + test_args_main = {"measure_cfg": conf} + test_args_greedy = {"confidence_measure_cfg": conf} + else: + raise NotImplementedError(arg) + confidence_cfg = ConfidenceConfig(preserve_word_confidence=True, **test_args_main) + model.change_decoding_strategy( + RNNTDecodingConfig(fused_batch_size=-1, strategy="greedy", confidence_cfg=confidence_cfg) + if model_name == "rnnt" + else CTCDecodingConfig(confidence_cfg=confidence_cfg) + ) + assert model.cfg.decoding.confidence_cfg.measure_cfg.alpha == 0.5 + model.change_decoding_strategy( + RNNTDecodingConfig( + fused_batch_size=-1, + strategy="greedy", + greedy=GreedyRNNTInferConfig(preserve_frame_confidence=True, **test_args_greedy), + ) + if model_name == "rnnt" + else CTCDecodingConfig(greedy=GreedyCTCInferConfig(preserve_frame_confidence=True, **test_args_greedy)) + ) + assert model.cfg.decoding.greedy.confidence_measure_cfg.alpha == 0.5 diff --git a/tests/collections/asr/confidence/test_asr_confidence_metrics.py b/tests/collections/asr/confidence/test_asr_confidence_metrics.py new file mode 100644 index 000000000000..fde5f322a988 --- /dev/null +++ b/tests/collections/asr/confidence/test_asr_confidence_metrics.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023, 
NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import tempfile + +import numpy as np +import pytest +from scipy.stats import uniform + +from nemo.collections.asr.parts.utils.confidence_metrics import ( + auc_nt, + auc_pr, + auc_roc, + auc_yc, + ece, + nce, + save_confidence_hist, + save_custom_confidence_curve, + save_nt_curve, + save_pr_curve, + save_roc_curve, +) + +# set convenient name2metric mapping +name2metric = { + f.__name__: (f, ans) + for f, ans in zip((auc_roc, auc_pr, auc_nt, auc_yc, ece, nce), (0.833, 0.917, 0.833, 0.421, 0.232, 0.403)) +} +# ece does not have a default value +name2metric_all_correct = { + f.__name__: (f, ans) for f, ans in zip((auc_roc, auc_pr, auc_nt, auc_yc, nce), (0.5, 1.0, 0.0, 0.0, -math.inf)) +} +name2metric_all_incorrect = { + f.__name__: (f, ans) for f, ans in zip((auc_roc, auc_pr, auc_nt, auc_yc, nce), (0.5, 0.0, 1.0, 0.0, -math.inf)) +} + +# Initialize data +Y_TRUE = [1, 0, 0, 1, 1] +Y_TRUE_ALL_CORRECT = [1, 1, 1, 1, 1] +Y_TRUE_ALL_INCORRECT = [0, 0, 0, 0, 0] +Y_SCORE = [0.6, 0.7, 0.02, 0.95, 0.8] +Y_TRUE_RANDOM = np.random.choice(2, 1000, p=[0.2, 0.8]) +# probability distribution with mean ~= 0.65 and std ~= 0.25 +Y_SCORE_RANDOM = uniform.rvs(size=1000, loc=0.5, scale=0.5) - 0.5 * np.random.choice(2, 1000, p=[0.8, 0.2]) + +TOL_DEGREE = 3 +TOL = 1 / math.pow(10, TOL_DEGREE) + + +class TestConfidenceMetrics: + @pytest.mark.unit + @pytest.mark.parametrize('metric_name', name2metric.keys()) + def test_metric_main(self, metric_name): + metric, ans = name2metric[metric_name] + + assert round(metric(Y_TRUE, Y_SCORE), TOL_DEGREE) == ans + + @pytest.mark.unit + @pytest.mark.parametrize('metric_name', name2metric_all_correct.keys()) + def test_metric_all_correct(self, metric_name): + metric, ans = name2metric_all_correct[metric_name] + + assert round(metric(Y_TRUE_ALL_CORRECT, Y_SCORE), TOL_DEGREE) == ans + + @pytest.mark.unit + @pytest.mark.parametrize('metric_name', name2metric_all_incorrect.keys()) + def test_metric_all_incorrect(self, metric_name): + metric, ans = name2metric_all_incorrect[metric_name] + + assert round(metric(Y_TRUE_ALL_INCORRECT, Y_SCORE), TOL_DEGREE) == ans + + @pytest.mark.unit + def test_metric_auc_yc_aux(self): + n_bins = 10 + result, result_std, result_max, (thresholds, yc_curve) = auc_yc( + Y_TRUE, Y_SCORE, n_bins=n_bins, return_std_maximum=True, return_curve=True + ) + + assert round(result_std, TOL_DEGREE) == 0.228 + assert round(result_max, TOL_DEGREE) == 0.667 + assert np.allclose(np.array(thresholds), np.array([i / n_bins for i in range(0, n_bins + 1)]), atol=TOL) + assert np.allclose( + np.array(yc_curve), np.array([0.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.167, 0.667, 0.667, 0.333, 0.0]), atol=TOL + ) + + +class TestSaveConfidencePlot: + @pytest.mark.unit + def test_save_confidence_hist(self): + with tempfile.TemporaryDirectory() as tmpdir: + save_confidence_hist(Y_SCORE_RANDOM, tmpdir) + + @pytest.mark.unit + @pytest.mark.parametrize('plot_func', (save_roc_curve, save_pr_curve, 
save_nt_curve)) + def test_save_simple_confidence_curve(self, plot_func): + with tempfile.TemporaryDirectory() as tmpdir: + plot_func(Y_TRUE_RANDOM, Y_SCORE_RANDOM, tmpdir) + + @pytest.mark.unit + def test_save_custom_confidence_curve(self): + with tempfile.TemporaryDirectory() as tmpdir: + ranges = np.arange(0, 1, 0.01) + save_custom_confidence_curve(ranges, ranges, tmpdir) diff --git a/tests/collections/asr/confidence/test_asr_confidence_primitives.py b/tests/collections/asr/confidence/test_asr_confidence_primitives.py new file mode 100644 index 000000000000..d1111406ca62 --- /dev/null +++ b/tests/collections/asr/confidence/test_asr_confidence_primitives.py @@ -0,0 +1,142 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import pytest +import torch + +from nemo.collections.asr.parts.utils.asr_confidence_utils import ( + get_confidence_aggregation_bank, + get_confidence_measure_bank, +) + +# Initialize probability vectors +VOCAB_SIZES = (100, 1000, 10000) +ONE_VEC_SET, ZERO_VEC_SET, RAND_VEC_SET, OVERFIT_RAND_VEC_SET = {}, {}, {}, {} +for vocab_size in VOCAB_SIZES: + # batch size 2 to test different positions of probability one + ONE_VEC_SET[vocab_size] = torch.nan_to_num( + torch.cat( + [ + torch.tensor([[0] + [float('-inf')] * (vocab_size - 1)]), + torch.tensor([[float('-inf')] * (vocab_size - 3) + [0] + [float('-inf')] * 2]), + ] + ) + ) + ZERO_VEC_SET[vocab_size] = torch.nan_to_num(torch.tensor([[math.log(1 / vocab_size)] * vocab_size] * 2)) + # batch size 1 + rand_logit = torch.rand((1, vocab_size)) + rand_logit_overfit = rand_logit.clone() + rand_logit_overfit[0, 0] += vocab_size + RAND_VEC_SET[vocab_size] = torch.nan_to_num(torch.nn.functional.log_softmax(rand_logit, -1)) + OVERFIT_RAND_VEC_SET[vocab_size] = torch.nan_to_num(torch.nn.functional.log_softmax(rand_logit_overfit, -1)) +AGGREGATION_VEC_SIMPLE = [0.0, 0.5, 1] + +TOL_DEGREE = 6 +TOL = 1 / math.pow(10, TOL_DEGREE) + + +def get_measure_parametrize_ranges(): + confidence_measure_bank = {} + alpha_range = (0.25, 0.5, 1.0) + bank_exception = None + try: + confidence_measure_bank = get_confidence_measure_bank() + except Exception as e: + alpha_range = () + bank_exception = e + return confidence_measure_bank, alpha_range, bank_exception + + +def get_aggregation_parametrize_ranges(): + confidence_aggregation_bank = {} + bank_exception = None + try: + confidence_aggregation_bank = get_confidence_aggregation_bank() + except Exception as e: + bank_exception = e + return confidence_aggregation_bank, bank_exception + + +class TestConfidenceMeasureBank: + measure_bank, alphas, bank_build_exception = get_measure_parametrize_ranges() + + @pytest.mark.unit + def test_measure_bank(self): + if self.bank_build_exception is not None: + raise self.bank_build_exception + + assert isinstance(self.measure_bank, dict) + assert len(self.measure_bank) > 0 + + @pytest.mark.unit + @pytest.mark.parametrize('measure_name', measure_bank.keys()) + 
@pytest.mark.parametrize('alpha', alphas) + @pytest.mark.parametrize('vocab_size', VOCAB_SIZES) + def test_confidence_measures_one(self, measure_name, alpha, vocab_size): + measure = self.measure_bank[measure_name] + + assert torch.allclose(measure(ONE_VEC_SET[vocab_size], vocab_size, alpha), torch.tensor([1.0, 1.0]), atol=TOL) + + @pytest.mark.unit + @pytest.mark.parametrize('measure_name', measure_bank.keys()) + @pytest.mark.parametrize('alpha', alphas) + @pytest.mark.parametrize('vocab_size', VOCAB_SIZES) + def test_confidence_measures_zero(self, measure_name, alpha, vocab_size): + measure = self.measure_bank[measure_name] + + assert torch.allclose(measure(ZERO_VEC_SET[vocab_size], vocab_size, alpha), torch.tensor([0.0, 0.0]), atol=TOL) + + @pytest.mark.unit + @pytest.mark.parametrize('measure_name', measure_bank.keys()) + @pytest.mark.parametrize('alpha', alphas) + @pytest.mark.parametrize('vocab_size', VOCAB_SIZES) + def test_confidence_measures_partial_order(self, measure_name, alpha, vocab_size): + measure = self.measure_bank[measure_name] + value_normal = round(float(measure(RAND_VEC_SET[vocab_size], vocab_size, alpha)[0]), TOL_DEGREE) + value_overfit = round(float(measure(OVERFIT_RAND_VEC_SET[vocab_size], vocab_size, alpha)[0]), TOL_DEGREE) + + assert 0 <= value_normal < value_overfit <= 1, ( + measure(RAND_VEC_SET[vocab_size], vocab_size, alpha), + measure(OVERFIT_RAND_VEC_SET[vocab_size], vocab_size, alpha), + ) + + +class TestConfidenceAggregationBank: + aggregation_bank, bank_build_exception = get_aggregation_parametrize_ranges() + + @pytest.mark.unit + def test_aggregation_bank(self): + if self.bank_build_exception is not None: + raise self.bank_build_exception + + assert isinstance(self.aggregation_bank, dict) + assert len(self.aggregation_bank) > 0 + + @pytest.mark.unit + @pytest.mark.parametrize('aggregation_name', aggregation_bank.keys()) + def test_confidence_agregation_simple(self, aggregation_name): + # alaptev: would skipif work with parametrize arguments? 
+ if aggregation_name not in ("mean", "min", "max", "prod"): + pytest.skip(f"{aggregation_name} is not a simple aggregation") + aggregation = self.aggregation_bank[aggregation_name] + if aggregation_name == "mean": + assert aggregation(AGGREGATION_VEC_SIMPLE) == 0.5 + elif aggregation_name == "min": + assert aggregation(AGGREGATION_VEC_SIMPLE) == 0.0 + if aggregation_name == "max": + assert aggregation(AGGREGATION_VEC_SIMPLE) == 1.0 + if aggregation_name == "prod": + assert aggregation(AGGREGATION_VEC_SIMPLE) == 0.0 diff --git a/tests/collections/asr/test_asr_classification_model.py b/tests/collections/asr/test_asr_classification_model.py index 876bb6073a38..3888cb30204c 100644 --- a/tests/collections/asr/test_asr_classification_model.py +++ b/tests/collections/asr/test_asr_classification_model.py @@ -94,8 +94,8 @@ def frame_classification_model(): } decoder = { - 'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification', - 'params': {'feat_in': 32, 'num_classes': 5,}, + 'cls': 'nemo.collections.common.parts.MultiLayerPerceptron', + 'params': {'hidden_size': 32, 'num_classes': 5,}, } modelConfig = DictConfig( diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py index 22926b6516ee..8687ed683833 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -242,7 +242,8 @@ def test_decoding_change(self, hybrid_asr_model): @pytest.mark.unit def test_GreedyRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + # confidence_method_cfg is deprecated + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'confidence_method_cfg'] result = assert_dataclass_signature_match( greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyRNNTInferConfig, ignore_args=IGNORE_ARGS @@ -256,7 +257,8 @@ def test_GreedyRNNTInferConfig(self): @pytest.mark.unit def test_GreedyBatchedRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + # confidence_method_cfg is deprecated + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'confidence_method_cfg'] result = assert_dataclass_signature_match( greedy_decode.GreedyBatchedRNNTInfer, greedy_decode.GreedyBatchedRNNTInferConfig, ignore_args=IGNORE_ARGS diff --git a/tests/collections/asr/test_asr_metrics.py b/tests/collections/asr/test_asr_metrics.py index 9a43ed4e2b90..2c4ec0953444 100644 --- a/tests/collections/asr/test_asr_metrics.py +++ b/tests/collections/asr/test_asr_metrics.py @@ -32,6 +32,7 @@ CTCDecodingConfig, word_error_rate, word_error_rate_detail, + word_error_rate_per_utt, ) from nemo.collections.asr.metrics.wer_bpe import WERBPE, CTCBPEDecoding, CTCBPEDecodingConfig from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis @@ -136,6 +137,15 @@ def test_wer_function(self): 0.0, ) + assert word_error_rate_per_utt(hypotheses=['kat'], references=['cat']) == ([1.0], 1.0) + assert word_error_rate_per_utt(hypotheses=['cat', ''], references=['', 'gpu']) == ([float("inf"), 1.0], 2.0) + assert word_error_rate_per_utt( + hypotheses=['ducuti motorcycle', 'G P U'], references=['ducati motorcycle', 'GPU'] + ) == ([0.5, 3.0], 4 / 3) + assert word_error_rate_per_utt( + hypotheses=['ducuti motorcycle', 'G P U'], references=['ducati motorcycle', 'GPU'], use_cer=True + ) == ([1 / 17, 2 / 3], 0.15) + @pytest.mark.unit @pytest.mark.parametrize("batch_dim_index", [0, 1]) @pytest.mark.parametrize("test_wer_bpe", [False, 
True]) diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index 68f1e38f797b..775a146c74c4 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -242,7 +242,8 @@ def test_decoding_change(self, asr_model): @pytest.mark.unit def test_GreedyRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + # confidence_method_cfg is deprecated + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'confidence_method_cfg'] result = assert_dataclass_signature_match( greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyRNNTInferConfig, ignore_args=IGNORE_ARGS @@ -256,7 +257,8 @@ def test_GreedyRNNTInferConfig(self): @pytest.mark.unit def test_GreedyBatchedRNNTInferConfig(self): - IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index'] + # confidence_method_cfg is deprecated + IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index', 'confidence_method_cfg'] result = assert_dataclass_signature_match( greedy_decode.GreedyBatchedRNNTInfer, greedy_decode.GreedyBatchedRNNTInferConfig, ignore_args=IGNORE_ARGS diff --git a/tests/collections/asr/test_confidence_ensembles.py b/tests/collections/asr/test_confidence_ensembles.py index ad14a2a7e6ff..b8b027dd3426 100644 --- a/tests/collections/asr/test_confidence_ensembles.py +++ b/tests/collections/asr/test_confidence_ensembles.py @@ -19,7 +19,7 @@ from nemo.collections.asr.metrics.wer import CTCDecodingConfig from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecRNNTModel from nemo.collections.asr.models.confidence_ensemble import ConfidenceEnsembleModel -from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, ConfidenceMethodConfig +from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, ConfidenceMeasureConfig def get_model_config(model_class): @@ -117,12 +117,7 @@ def test_model_creation_2models(self, tmp_path, model_class0, model_class1): preserve_frame_confidence=True, exclude_blank=True, aggregation="mean", - method_cfg=ConfidenceMethodConfig( - name="entropy", - entropy_type="renui", - temperature=0.25, # this is not really temperature, but alpha, see https://arxiv.org/abs/2212.08703 - entropy_norm="lin", - ), + measure_cfg=ConfidenceMeasureConfig(name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",), ) # just checking that no errors are raised when creating the model @@ -153,12 +148,7 @@ def test_model_creation_5models(self, tmp_path): preserve_frame_confidence=True, exclude_blank=True, aggregation="mean", - method_cfg=ConfidenceMethodConfig( - name="entropy", - entropy_type="renui", - temperature=0.25, # this is not really temperature, but alpha, see https://arxiv.org/abs/2212.08703 - entropy_norm="lin", - ), + measure_cfg=ConfidenceMeasureConfig(name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",), ) # just checking that no errors are raised when creating the model diff --git a/tests/collections/nlp/test_gpt_eval.py b/tests/collections/nlp/test_gpt_eval.py index 0e64b989176f..fb3f9fda5ac3 100644 --- a/tests/collections/nlp/test_gpt_eval.py +++ b/tests/collections/nlp/test_gpt_eval.py @@ -78,6 +78,7 @@ def test_gpt_eval(self): "add_BOS": True, "all_probs": False, "compute_logprob": False, + "end_strings": ["<|endoftext|>"], } # test logprob diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py 
index 05b23e6afb1b..67f016b0c2af 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -54,8 +54,7 @@ def radtts_model(): model = RadTTSModel(cfg=cfg.model) app_state.is_model_being_restored = False model.eval() - model.export_config['enable_ragged_batches'] = True - model.export_config['enable_volume'] = True + model.set_export_config({'enable_ragged_batches': 'True', 'enable_volume': 'True'}) return model diff --git a/tutorials/asr/ASR_Confidence_Estimation.ipynb b/tutorials/asr/ASR_Confidence_Estimation.ipynb new file mode 100644 index 000000000000..2a1ad024a889 --- /dev/null +++ b/tutorials/asr/ASR_Confidence_Estimation.ipynb @@ -0,0 +1,1432 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "abe9913d", + "metadata": { + "id": "1a0f93c6" + }, + "outputs": [], + "source": [ + "BRANCH = 'main'\n", + "\n", + "\"\"\"\n", + "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run this cell to set up dependencies.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd26974d", + "metadata": { + "id": "ffdfe626" + }, + "outputs": [], + "source": [ + "import os\n", + "# either provide a path to local NeMo repository with NeMo already installed or git clone\n", + "\n", + "# option #1: local path to NeMo repo with NeMo already installed\n", + "NEMO_DIR_PATH = os.path.dirname(os.path.dirname(os.path.abspath('')))\n", + "is_colab = False\n", + "\n", + "# option #2: download NeMo repo\n", + "if 'google.colab' in str(get_ipython()) or not os.path.exists(os.path.join(NEMO_DIR_PATH, \"nemo\")):\n", + " ## Install dependencies\n", + " !apt-get install sox libsndfile1 ffmpeg\n", + "\n", + " !git clone -b $BRANCH https://github.com/NVIDIA/NeMo\n", + " %cd NeMo\n", + " !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", + " NEMO_DIR_PATH = os.path.abspath('')\n", + " is_colab = True\n", + "\n", + "import sys\n", + "sys.path.insert(0, NEMO_DIR_PATH)" + ] + }, + { + "cell_type": "markdown", + "id": "b3f35d50", + "metadata": { + "id": "bcc3e593" + }, + "source": [ + "# 1. Introduction to ASR confidence estimation\n", + "Confidence estimation is a crucial yet sometimes overlooked aspect of automatic speech recognition (ASR) systems. Confidence estimation for ASR is the process of estimating the rate of reliability of the output generated by an ASR system. For an output transcription, confidence estimation answers the question \"how accurate this transcription is\", or \"how likely this transcription is correct\".\n", + "\n", + "Confidence score is the result of confidence estimation. It lies in range from 0 to 1, where zero signals that the confidence estimator is completely unsure, and one indicates that the estimator is confident in the output. Confidence scores are often used to guide downstream processing in ASR applications. 
For example, in a voice dictation application, a low confidence score could trigger the system to ask the user to repeat the input or to suggest alternative transcriptions.\n", + "\n", + "There are several approaches to confidence estimation in ASR, including:\n", + "\n", + "1. Acoustic modeling-based methods: These methods use the acoustic model scores to estimate the confidence score. The acoustic model represents the relationship between the acoustic signal and the corresponding linguistic units, and the score reflects the similarity between the observed signal and the predicted model output. Here, the acoustic model can be the ASR model itself (non-trainable methods), or a trainable external estimator, accepting acoustic features or output probabilities and predicting confidence scores.\n", + "\n", + "2. Language modeling-based methods: These methods use the language model scores to estimate the confidence score. The language model represents the probability distribution of the sequence of words, and the score reflects the likelihood of the transcription given the language model. \n", + "\n", + "3. Combination methods: These methods combine the scores from both the acoustic and language models to estimate the confidence score. This approach can leverage the strengths of both models to achieve more accurate confidence scores.\n", + "\n", + "In this introductory tutorial we will cover only the non-trainable acoustic-based methods." + ] + }, + { + "cell_type": "markdown", + "id": "34e356bf", + "metadata": { + "id": "59100fb9" + }, + "source": [ + "## 1.1. Optional resources\n", + "This tutorial is self-contained, but if you want to dive deeper into the topic, you can check out these resources:\n", + "* Paper behind this tutorial: https://arxiv.org/abs/2212.08703\n", + "* Supplementary blog on how and why confidence estimation methods of this tutorial were developed: https://developer.nvidia.com/blog/entropy-based-methods-for-word-level-asr-confidence-estimation/" + ] + }, + { + "cell_type": "markdown", + "id": "9739cb35", + "metadata": { + "id": "cd7226c5" + }, + "source": [ + "# 2. Data Download\n", + "First, let's download audio and text data. Here we will use LibriSpeech *dev-other* and *test-other*." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46b2861b", + "metadata": { + "id": "fd542e62" + }, + "outputs": [], + "source": [ + "## create data directory and download an audio file\n", + "WORK_DIR = 'WORK_DIR'\n", + "DATA_DIR = WORK_DIR + '/DATA'\n", + "os.makedirs(DATA_DIR, exist_ok=True)\n", + "\n", + "print('downloading audio data...')\n", + "!python $NEMO_DIR_PATH/scripts/dataset_processing/get_librispeech_data.py --data_root=$DATA_DIR --data_set=test_other\n", + "!rm $DATA_DIR/test_other.tar.gz" + ] + }, + { + "cell_type": "markdown", + "id": "8ba5ad12", + "metadata": { + "id": "383eee71" + }, + "source": [ + "# 3. Confidence estimation example\n", + "Let's see how confidence scores can be obtained with NeMo models." + ] + }, + { + "cell_type": "markdown", + "id": "a95697fe", + "metadata": { + "id": "7c7c0170" + }, + "source": [ + "## 3.1. Helper functions\n", + "The following functions are to pretty-print confidence scores for word-level ASR hypotheses." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0bd12b7b", + "metadata": { + "id": "20cf0b38" + }, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "from termcolor import colored\n", + "from typing import List, Optional, Tuple, Union\n", + "\n", + "from IPython.display import Audio, HTML, Image, display\n", + "import numpy as np\n", + "import texterrors\n", + "\n", + "def get_detailed_wer_labels(ref: List[str], hyp: List[str], return_eps_padded_hyp: bool = False):\n", + " \"\"\"Get detailed WER labels, aligning reference with hypothesis.\n", + " \n", + " Possible WER labels:\n", + " - 'C' for Correct,\n", + " - 'I' for Insertion,\n", + " - 'D' for Deletion,\n", + " - 'S' for Substitution.\n", + "\n", + " Returns:\n", + " WER labels list.\n", + " [Optional] Epsilin-padded hypothesis if return_eps_padded_hyp set to True.\n", + " \"\"\"\n", + "\n", + " # Align reference and hypothesis using \"\"\n", + " aligned_ref, aligned_hyp = texterrors.align_texts(ref, hyp, False)[:-1]\n", + "\n", + " # Determine labels\n", + " labels = []\n", + " for r, h in zip(aligned_ref, aligned_hyp):\n", + " if r == h:\n", + " labels.append(\"C\")\n", + " elif r == \"\":\n", + " labels.append(\"I\")\n", + " elif h == \"\":\n", + " labels.append(\"D\")\n", + " else:\n", + " labels.append(\"S\")\n", + "\n", + " return labels if not return_eps_padded_hyp else labels, aligned_hyp\n", + "\n", + "\n", + "def fill_confidence_deletions(confidence_scores: List[float], labels: List[str], fill_value: float = 0.0):\n", + " \"\"\"Fill confidence scores list with the provided value for deletions.\n", + " Assumes that we have no natural confidence scores for deletions.\n", + " \n", + " Returns:\n", + " Confidence scores list with deletion scores.\n", + " \"\"\"\n", + "\n", + " assert len(confidence_scores) <= len(labels)\n", + "\n", + " # If the lengths of confidence_scores and labels are equal, then we assume that there are no deletions\n", + " if len(confidence_scores) == len(labels):\n", + " return confidence_scores\n", + "\n", + " # Insert fill_value into confidence_scores where label == \"D\"\n", + " new_confidence_scores = []\n", + " score_index = 0\n", + " for label in labels:\n", + " if label == \"D\":\n", + " new_confidence_scores.append(fill_value)\n", + " else:\n", + " new_confidence_scores.append(confidence_scores[score_index])\n", + " score_index += 1\n", + " return new_confidence_scores\n", + "\n", + "\n", + "def pretty_pad_word_labels(labels: List[str], words: List[str]):\n", + " \"\"\"Pad word labels with dash for pretty printing.\n", + " Expects labels and words to have the same length.\n", + " \n", + " Returns:\n", + " Padded labels list.\n", + " \"\"\"\n", + " \n", + " # Check that words and labels without 'D' have the same length\n", + " assert len(words) == len(labels)\n", + "\n", + " # Pad the labels with dashes to align them with the words\n", + " padded_labels = []\n", + " for word, label in zip(words, labels):\n", + " label_len = len(word)\n", + " left_padding = (label_len - 1) // 2\n", + " right_padding = label_len - left_padding - 1\n", + " padded_label = \"-\" * left_padding + label + \"-\" * right_padding\n", + " padded_labels.append(padded_label)\n", + "\n", + " return padded_labels\n", + "\n", + "\n", + "def _html_paint_word_grey(word: str, shade: str):\n", + " if shade == \"black\":\n", + " color = \"0,0,0\"\n", + " elif shade == \"grey\":\n", + " color = \"150,150,150\"\n", + " elif shade == \"light_grey\":\n", + " color = \"200,200,200\"\n", + " 
else:\n", + " raise ValueError(\n", + " f\"`shade` has to be one of the following: `black`, `grey`, `light_grey`. Provided: `{shade}`\"\n", + " )\n", + " return f'{word}'\n", + "\n", + "\n", + "def pretty_print_transcript_with_confidence(\n", + " transcript: str,\n", + " confidence_scores: List[float],\n", + " threshold: float,\n", + " reference: Optional[str] = None,\n", + " terminal_width: int = 120,\n", + " html: bool = False,\n", + "):\n", + " if html:\n", + " shade_if_low_confidence = lambda x, y: _html_paint_word_grey(x, 'light_grey' if y < threshold else 'black')\n", + " new_line_mark = \"
\"\n", + " pretty_print = lambda x: display(HTML(\"\" + new_line_mark.join(x) + \"\"))\n", + " else:\n", + " shade_if_low_confidence = lambda x, y: colored(x, 'light_grey') if y < threshold else x\n", + " new_line_mark = \"\\n\"\n", + " pretty_print = lambda x: print(new_line_mark.join(x))\n", + " with_labels = reference is not None\n", + " transcript_list = transcript.split()\n", + " output_lines = []\n", + " if with_labels:\n", + " reference_list = reference.split()\n", + " labels, eps_padded_hyp = get_detailed_wer_labels(reference_list, transcript_list, True)\n", + " padded_labels = pretty_pad_word_labels(labels, eps_padded_hyp)\n", + " current_line_len = 0\n", + " current_word_line = \"\"\n", + " current_label_line = \"\"\n", + " for word, label, padded_label, score in zip(\n", + " eps_padded_hyp, labels, padded_labels, fill_confidence_deletions(confidence_scores, labels)\n", + " ):\n", + " word_len = len(word)\n", + " # shield angle brakets for \n", + " if html and word == \"\":\n", + " word = \"<eps>\"\n", + " if current_line_len + word_len + 1 <= terminal_width:\n", + " if current_line_len > 0:\n", + " current_line_len += 1\n", + " current_word_line += \" \"\n", + " current_label_line += \"-\"\n", + " current_line_len += word_len\n", + " current_word_line += shade_if_low_confidence(word, score)\n", + " current_label_line += padded_label\n", + " else:\n", + " output_lines.append(current_word_line + new_line_mark + current_label_line)\n", + " current_line_len = word_len\n", + " current_word_line = shade_if_low_confidence(word, score)\n", + " current_label_line = padded_label\n", + " if current_word_line:\n", + " output_lines.append(current_word_line + new_line_mark + current_label_line)\n", + " else:\n", + " current_line_len = 0\n", + " current_word_line = \"\"\n", + " for word, score in zip(transcript_list, confidence_scores):\n", + " word_len = len(word)\n", + " # shield angle brakets for \n", + " if html and word == \"\":\n", + " word = \"<eps>\"\n", + " if current_line_len + word_len + 1 <= terminal_width:\n", + " if current_line_len > 0:\n", + " current_line_len += 1\n", + " current_word_line += \" \"\n", + " current_line_len += word_len\n", + " current_word_line += shade_if_low_confidence(word, score)\n", + " else:\n", + " output_lines.append(current_word_line)\n", + " current_line_len = word_len\n", + " current_word_line = shade_if_low_confidence(word, score)\n", + " if current_word_line:\n", + " output_lines.append(current_word_line)\n", + "\n", + " pretty_print(output_lines)" + ] + }, + { + "cell_type": "markdown", + "id": "ed997bfd", + "metadata": { + "id": "dec57a27" + }, + "source": [ + "## 3.2. Data and model loading\n", + "This tutorial uses CTC and RNN-T Conformer models trained on LibriSpeech.\n", + "\n", + "You can try to use other pre-trained models as well." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70c1a27a", + "metadata": { + "id": "b66c60a3" + }, + "outputs": [], + "source": [ + "from dataclasses import dataclass\n", + "from omegaconf import DictConfig, OmegaConf\n", + "\n", + "from nemo.collections.asr.models import ASRModel\n", + "\n", + "def load_model(name: str):\n", + " \"\"\"Load a pre-trained model.\n", + "\n", + " Args:\n", + " name: Pre-trained model name.\n", + " Reserved names:\n", + " - 'ctc' for 'stt_en_conformer_ctc_large_ls'\n", + " - 'rnnt' for 'stt_en_conformer_transducer_large_ls'\n", + "\n", + " Returns:\n", + " A model loaded into GPU with .eval() mode set.\n", + " \"\"\"\n", + " if name == \"ctc\":\n", + " name = \"stt_en_conformer_ctc_large_ls\"\n", + " elif name == \"rnnt\":\n", + " name = \"stt_en_conformer_transducer_large_ls\"\n", + "\n", + " model = ASRModel.from_pretrained(model_name=name, map_location=\"cuda:0\")\n", + " model.eval()\n", + "\n", + " return model\n", + "\n", + "@dataclass\n", + "class TestSet:\n", + " filepaths: List[str]\n", + " reference_texts: List[str]\n", + " durations: List[float]\n", + "\n", + "def load_data(manifest_path: str):\n", + " filepaths = []\n", + " reference_texts = []\n", + " durations = []\n", + " with open(manifest_path, \"r\") as f:\n", + " for line in f:\n", + " item = json.loads(line)\n", + " audio_file = item[\"audio_filepath\"]\n", + " filepaths.append(str(audio_file))\n", + " text = item[\"text\"]\n", + " reference_texts.append(text)\n", + " durations.append(float(item[\"duration\"]))\n", + " return TestSet(filepaths, reference_texts, durations)\n", + "\n", + "TEST_MANIFESTS = {\n", + " \"test_other\": DATA_DIR + \"/test_other.json\",\n", + "}\n", + "\n", + "\n", + "# Load data\n", + "test_sets = {manifest: load_data(path) for manifest, path in TEST_MANIFESTS.items()}\n", + "\n", + "# Load model\n", + "is_rnnt = False\n", + "# is_rnnt = True\n", + "\n", + "model = load_model(\"rnnt\" if is_rnnt else \"ctc\")" + ] + }, + { + "cell_type": "markdown", + "id": "9c5db700", + "metadata": { + "id": "88c3d7ee" + }, + "source": [ + "## 3.3. Setting up confidence estimation\n", + "To set up confidence estimation for NeMo ASR models, you need to:\n", + "1. Initialize _ConfidenceConfig_\n", + "2. Put the created _ConfidenceConfig_ into the model decoding config.\n", + "\n", + "The folloving cell contains an example of _ConfidenceConfig_ initialization and updating the the model's decoding config.\n", + "\n", + "For the _ConfidenceConfig_ there are also listed possible values for its parameters.\n", + "\n", + "Note that only `strategy=\"greedy\"` (or `greedy_batch` for RNN-T) supports computing confidence scores." 
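For intuition, a per-frame confidence measure maps a frame's output distribution over tokens to a single score in [0, 1]. The sketch below illustrates the two families configured in the next cell, maximum probability and (linearly normalized) Gibbs entropy. It is only a toy illustration of the idea, not NeMo's implementation, and the example distribution is made up.

```python
import numpy as np

def max_prob_confidence(frame_probs: np.ndarray) -> float:
    # Confidence = highest softmax probability of the frame.
    return float(frame_probs.max())

def entropy_confidence(frame_probs: np.ndarray) -> float:
    # Gibbs (Shannon) entropy, linearly normalized to [0, 1]:
    # 1.0 for a one-hot (confident) distribution, 0.0 for a uniform (uncertain) one.
    eps = 1e-10
    entropy = -np.sum(frame_probs * np.log(frame_probs + eps))
    return float(1.0 - entropy / np.log(len(frame_probs)))

toy_frame = np.array([0.85, 0.05, 0.05, 0.03, 0.02])  # made-up 5-class frame distribution
print(max_prob_confidence(toy_frame), entropy_confidence(toy_frame))
```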
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d3e8c11", + "metadata": { + "id": "078005f1" + }, + "outputs": [], + "source": [ + "from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig\n", + "from nemo.collections.asr.metrics.wer import CTCDecodingConfig\n", + "from nemo.collections.asr.parts.utils.asr_confidence_utils import (\n", + " ConfidenceConfig,\n", + " ConfidenceConstants,\n", + " ConfidenceMeasureConfig,\n", + " ConfidenceMeasureConstants,\n", + ")\n", + "from nemo.collections.asr.parts.utils.asr_confidence_benchmarking_utils import (\n", + " apply_confidence_parameters,\n", + " get_correct_marks,\n", + " get_token_targets_with_confidence,\n", + " get_word_targets_with_confidence,\n", + ")\n", + "\n", + "\n", + "# List allowed options for ConfidenceMeasureConfig and ConfidenceConfig\n", + "print(f\"Allowed options for ConfidenceMeasureConfig: {ConfidenceMeasureConstants.print()}\\n\")\n", + "print(f\"Allowed options for ConfidenceConfig: {ConfidenceConstants.print()}\\n\")\n", + "\n", + "# Initialize ConfidenceConfig and ConfidenceMeasureConfig\n", + "confidence_cfg = ConfidenceConfig(\n", + " preserve_frame_confidence=True, # Internally set to true if preserve_token_confidence == True\n", + " # or preserve_word_confidence == True\n", + " preserve_token_confidence=True, # Internally set to true if preserve_word_confidence == True\n", + " preserve_word_confidence=True,\n", + " aggregation=\"prod\", # How to aggregate frame scores to token scores and token scores to word scores\n", + " exclude_blank=False, # If true, only non-blank emissions contribute to confidence scores\n", + " measure_cfg=ConfidenceMeasureConfig( # Config for per-frame scores calculation (before aggregation)\n", + " name=\"max_prob\", # Or \"entropy\" (default), which usually works better\n", + " entropy_type=\"gibbs\", # Used only for name == \"entropy\". Recommended: \"tsallis\" (default) or \"renyi\"\n", + " alpha=0.5, # Low values (<1) increase sensitivity, high values decrease sensitivity\n", + " entropy_norm=\"lin\" # How to normalize (map to [0,1]) entropy. Default: \"exp\"\n", + " )\n", + ")\n", + "\n", + "# Alternalively, look at ConfidenceConfig's docstring\n", + "print(f\"More info on ConfidenceConfig here:\\n{ConfidenceConfig().__doc__}\\n\")\n", + "\n", + "# Put the created ConfidenceConfig into the model decoding config via .change_decoding_strategy()\n", + "model.change_decoding_strategy(\n", + " RNNTDecodingConfig(fused_batch_size=-1, strategy=\"greedy_batch\", confidence_cfg=confidence_cfg)\n", + " if is_rnnt\n", + " else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "04581687", + "metadata": { + "id": "efe0baea" + }, + "source": [ + "## 3.4. Decode test set and get transcriptions with confidence scores\n", + "Let's transcribe Librispeech _test-other_ and see what confidence scores are inside." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5f92257", + "metadata": { + "id": "ccd8d0de" + }, + "outputs": [], + "source": [ + "current_test_set = test_sets[\"test_other\"]\n", + "transcriptions = model.transcribe(paths2audio_files=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", + "if is_rnnt:\n", + " transcriptions = transcriptions[0]" + ] + }, + { + "cell_type": "markdown", + "id": "ca282352", + "metadata": { + "id": "0500514e" + }, + "source": [ + "For a transcribed hypothesis, there can be `frame_confidence` and aggregated from them `token_confidence` and `word_confidence`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18663384", + "metadata": { + "id": "98035fd2" + }, + "outputs": [], + "source": [ + "tran = transcriptions[0]\n", + "print(\n", + " f\"\"\" Recognized text: `{tran.text}`\\n\n", + " Word confidence: {[round(c, 3) for c in tran.word_confidence]}\\n\n", + " Token confidence: {[round(c, 3) for c in tran.token_confidence]}\\n\n", + " Frame confidence: {[([round(cc, 3) for cc in c] if is_rnnt else round(c, 3)) for c in tran.frame_confidence]}\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "783e9e2a", + "metadata": { + "id": "9613bfc1" + }, + "source": [ + "Now let's draw the recognition results highlighted according to their confidence scores.\n", + "\n", + "There are four options: plain text and HTML with or without WER labels." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "642fe059", + "metadata": { + "id": "a83295ff" + }, + "outputs": [], + "source": [ + "from nemo.collections.asr.metrics.wer import word_error_rate, word_error_rate_detail, word_error_rate_per_utt\n", + "\n", + "def show_dataset_with_confidence(\n", + " indices,\n", + " transcriptions,\n", + " test_set,\n", + " threshold,\n", + " filepaths=None,\n", + " html_show=False,\n", + " min_dur_to_show=0.0,\n", + " utt_to_show=10\n", + "):\n", + " utt_shown = 0\n", + " for i, _ in indices:\n", + " if utt_shown >= utt_to_show:\n", + " break\n", + " if test_set.durations[i] >= min_dur_to_show:\n", + " print(\"=\"*120)\n", + " hyp = transcriptions[i].text\n", + " scores = transcriptions[i].word_confidence\n", + " ref = test_set.reference_texts[i]\n", + " pretty_print_transcript_with_confidence(hyp, scores, threshold, ref, html=html_show)\n", + " if filepaths is not None:\n", + " display(Audio(filepaths[i]))\n", + " utt_shown += 1\n", + "\n", + "\n", + "# you can play with these parameters\n", + "threshold = 0.52\n", + "# in colab, you may want to use `html_show = True` as non-html colorion displayed incorrectly in colab\n", + "html_show = is_colab\n", + "min_dur_to_show = 4.0\n", + "utt_to_show = 5\n", + "\n", + "wer_per_utt, avg_wer = word_error_rate_per_utt([h.text for h in transcriptions], current_test_set.reference_texts)\n", + "sorted_wer_indices = sorted(enumerate(wer_per_utt), key=lambda x: x[1])[::-1]\n", + "\n", + "show_dataset_with_confidence(\n", + " indices=sorted_wer_indices,\n", + " transcriptions=transcriptions,\n", + " test_set=current_test_set,\n", + " threshold=threshold,\n", + " filepaths=current_test_set.filepaths,\n", + " html_show=html_show,\n", + " min_dur_to_show=min_dur_to_show,\n", + " utt_to_show=utt_to_show\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9468ad3e", + "metadata": { + "id": "dbfcb2da" + }, + "source": [ + "## 3.5. 
Confidence metrics\n", + "\n", + "There are several metrics to evaluate the effectiveness of a confidence estimation method. Some of them consider confidence estimation as a binary classification task. Others measure how close the correct word confidence scores are to $1.0$ and the incorrect word scores are to $0.0$.\n", + "\n", + "Some of them are:\n", + "1. Area Under the Receiver Operating Characteristics Curve ($\\mathrm{AUC}_\\mathrm{ROC}$): class separability metric.\n", + "2. Area Under the Precision-Recall Curve ($\\mathrm{AUC}_\\mathrm{PR}$): how well the correct words are detected.\n", + "3. Area Under the Negative Predictive Value vs. True Negative Rate Curve ($\\mathrm{AUC}_\\mathrm{NT}$): how well the incorrect words are detected ($\\mathrm{AUC}_\\mathrm{PR}$ in which errors are treated as positives).\n", + "4. Normalized Cross Entropy ($\\mathrm{NCE}$): how close the confidence of correct predictions is to $1.0$ and that of incorrect predictions to $0.0$. It ranges from $-\\infty$ to $1.0$, with negative scores indicating that the confidence method performs worse than setting the confidence score to $1-\\mathrm{WER}$. This metric is also known as Normalized Mutual Information.\n", + "5. Expected Calibration Error ($\\mathrm{ECE}$): a weighted average of the absolute accuracy/confidence difference. It ranges from $0.0$ to $1.0$ with the best value $0.0$.\n", + "\n", + "Metrics based on the Youden's curve (see https://en.wikipedia.org/wiki/Youden%27s_J_statistic) can also be considered. They are:\n", + "1. Area Under the Youden's curve ($\\mathrm{AUC}_\\mathrm{YC}$): the rate of the effective threshold range (i.e. the adjustability or responsiveness). It ranges from $0.0$ to $1.0$ with the best value $0.5$.\n", + "2. Maximum of the Youden's curve ($\\mathrm{MAX}_\\mathrm{YC}$): the optimal $\\mathrm{TNR}$ vs. $\\mathrm{FNR}$ tradeoff. Its unnormalized version can be used as a criterion for selecting the optimal $\\tau$. It ranges from $0.0$ to $1.0$ with the best value $1.0$.\n", + "3. The standard deviation of the Youden's curve values ($\\mathrm{STD}_\\mathrm{YC}$): indicates that $\\mathrm{TNR}$ and $\\mathrm{FNR}$ increase at different rates (viz. $\\mathrm{TNR}$ grows faster) as $\\tau$ increases. It ranges from $0.0$ to $0.5$ with the best value around $0.25$.\n", + "\n", + "When selecting/tuning a confidence method, it is recommended to maximize $\\mathrm{AUC}_\\mathrm{ROC}$ first, as this is the main metric of confidence estimation quality. Then, for overconfident models, maximizing $\\mathrm{AUC}_\\mathrm{NT}$ should take precedence over $\\mathrm{AUC}_\\mathrm{PR}$. Finally, the trade-off between $\\mathrm{NCE}$/$\\mathrm{ECE}$ and the family of $\\mathrm{YC}$ metrics can be considered as a compromise between formal correctness and controllability.\n", + "\n", + "Let's see how well our confidence performs according to the metrics above."
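Of the metrics above, $\mathrm{NCE}$ has a particularly simple closed form. The snippet below is a minimal sketch of how it can be computed from binary word-correctness labels and word confidence scores; the benchmark in the next cell uses NeMo's own implementations from `confidence_metrics`, and the numbers here are made up.

```python
import numpy as np

def nce_sketch(y_true: np.ndarray, y_score: np.ndarray, eps: float = 1e-15) -> float:
    # Normalized Cross Entropy: 1.0 for perfect confidence scores,
    # 0.0 for a constant score equal to the prior word accuracy, negative if worse.
    y_score = np.clip(y_score, eps, 1 - eps)
    p_correct = y_true.mean()
    h_prior = -(p_correct * np.log(p_correct) + (1 - p_correct) * np.log(1 - p_correct))
    h_conf = -np.mean(y_true * np.log(y_score) + (1 - y_true) * np.log(1 - y_score))
    return float((h_prior - h_conf) / h_prior)

# Toy example: four correctly recognized words, one error with low confidence
print(nce_sketch(np.array([1, 1, 1, 1, 0]), np.array([0.9, 0.8, 0.95, 0.7, 0.2])))
```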
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b0fa793", + "metadata": { + "id": "5d152775" + }, + "outputs": [], + "source": [ + "from nemo.collections.asr.parts.utils.confidence_metrics import (\n", + " auc_nt,\n", + " auc_pr,\n", + " auc_roc,\n", + " auc_yc,\n", + " ece,\n", + " nce,\n", + " save_confidence_hist,\n", + " save_custom_confidence_curve,\n", + " save_nt_curve,\n", + " save_pr_curve,\n", + " save_roc_curve,\n", + ")\n", + "\n", + "\n", + "targets_with_confidence = [get_word_targets_with_confidence(tran) for tran in transcriptions]\n", + "correct_marks = [get_correct_marks(r.split(), h.words) for r, h in zip(current_test_set.reference_texts, transcriptions)]\n", + "\n", + "y_true, y_score = np.array(\n", + " [[f, p[1]] for cm, twc in zip(correct_marks, targets_with_confidence) for f, p in zip(cm, twc)]\n", + ").T\n", + "\n", + "\n", + "# output scheme: yc.mean(), yc.max(), yc.std() or yc.mean(), yc.max(), yc.std(), (thresholds, yc)\n", + "result_yc = auc_yc(y_true, y_score, return_std_maximum=True, return_curve=True)\n", + "# output scheme: ece or ece, (thresholds, ece_curve)\n", + "results_ece = ece(y_true, y_score, return_curve=True)\n", + "results = [\n", + " auc_roc(y_true, y_score),\n", + " auc_pr(y_true, y_score),\n", + " auc_nt(y_true, y_score),\n", + " nce(y_true, y_score),\n", + " results_ece[0],\n", + "] + list(result_yc[:3])\n", + "\n", + "print(\n", + " f\"\"\" AUC_ROC:\\t{results[0]:.5f}\n", + " AUC_PR:\\t{results[1]:.5f}\n", + " AUC_NT:\\t{results[2]:.5f}\n", + " NCE:\\t{results[3]:.5f}\n", + " ECE:\\t{results[4]:.5f}\n", + " AUC_YC:\\t{results[5]:.5f}\n", + " MAX_YC:\\t{results[7]:.5f}\n", + " STD_YC:\\t{results[6]:.5f}\n", + " \"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "0c3f6299", + "metadata": { + "id": "4159034d" + }, + "source": [ + "Confidence metrics for the maximum probability confidence are not that great.\n", + "\n", + "Let's re-run and benchmark confidence estimation with the default confidence estimator." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c0e3a9f", + "metadata": { + "id": "d2e16f5f" + }, + "outputs": [], + "source": [ + "confidence_cfg = ConfidenceConfig(\n", + " preserve_word_confidence=True,\n", + " preserve_token_confidence=True,\n", + ")\n", + "\n", + "model.change_decoding_strategy(\n", + " RNNTDecodingConfig(fused_batch_size=-1, strategy=\"greedy_batch\", confidence_cfg=confidence_cfg)\n", + " if is_rnnt\n", + " else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n", + ")\n", + "\n", + "transcriptions = model.transcribe(paths2audio_files=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", + "if is_rnnt:\n", + " transcriptions = transcriptions[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8f1cc77", + "metadata": { + "id": "6201ea4d" + }, + "outputs": [], + "source": [ + "targets_with_confidence = [get_word_targets_with_confidence(tran) for tran in transcriptions]\n", + "correct_marks = [get_correct_marks(r.split(), h.words) for r, h in zip(current_test_set.reference_texts, transcriptions)]\n", + "\n", + "y_true, y_score = np.array(\n", + " [[f, p[1]] for cm, twc in zip(correct_marks, targets_with_confidence) for f, p in zip(cm, twc)]\n", + ").T\n", + "\n", + "result_yc = auc_yc(y_true, y_score, return_std_maximum=True, return_curve=True)\n", + "results_ece = ece(y_true, y_score, return_curve=True)\n", + "results = [\n", + " auc_roc(y_true, y_score),\n", + " auc_pr(y_true, y_score),\n", + " auc_nt(y_true, y_score),\n", + " nce(y_true, y_score),\n", + " results_ece[0],\n", + "] + list(result_yc[:3])\n", + "\n", + "print(\n", + " f\"\"\" AUC_ROC:\\t{results[0]:.5f}\n", + " AUC_PR:\\t{results[1]:.5f}\n", + " AUC_NT:\\t{results[2]:.5f}\n", + " NCE:\\t{results[3]:.5f}\n", + " ECE:\\t{results[4]:.5f}\n", + " AUC_YC:\\t{results[5]:.5f}\n", + " MAX_YC:\\t{results[7]:.5f}\n", + " STD_YC:\\t{results[6]:.5f}\n", + " \"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9ab2b130", + "metadata": { + "id": "498e03d0" + }, + "source": [ + "Note that despite the overall improvement, $NCE$ and $ECE$ have gotten worse. This is due to class imbalance caused by low WER." + ] + }, + { + "cell_type": "markdown", + "id": "f96cea04", + "metadata": { + "id": "45856cba" + }, + "source": [ + "Now, let's draw $\\mathrm{ROC}$ as well as histograms of correctly and incorrectly recognized words." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81844713", + "metadata": { + "id": "ff049043" + }, + "outputs": [], + "source": [ + "from tempfile import TemporaryDirectory\n", + "\n", + "\n", + "plot_dir = TemporaryDirectory()\n", + "os.makedirs(plot_dir.name, exist_ok=True)\n", + "\n", + "mask_correct = y_true == 1\n", + "y_score_correct = y_score[mask_correct]\n", + "y_score_incorrect = y_score[~mask_correct]\n", + "\n", + "# histogram of the correct distribution\n", + "save_confidence_hist(y_score_correct, plot_dir.name, \"hist_correct\")\n", + "# histogram of the incorrect distribution\n", + "save_confidence_hist(y_score_incorrect, plot_dir.name, \"hist_incorrect\")\n", + "# AUC-ROC curve\n", + "save_roc_curve(y_true, y_score, plot_dir.name, \"roc\")\n", + "\n", + "\n", + "display(\n", + " Image(filename=os.path.join(plot_dir.name, \"hist_correct.png\"), retina=True),\n", + " Image(filename=os.path.join(plot_dir.name, \"hist_incorrect.png\"), retina=True),\n", + " Image(filename=os.path.join(plot_dir.name, \"roc.png\"), retina=True),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "841a27ca", + "metadata": {}, + "source": [ + "Optionally, you can look at curves for other metrics ($\\mathrm{PR}$, $\\mathrm{NT}$, $\\mathrm{ECE}$, and $\\mathrm{YC}$)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6164e8f0", + "metadata": {}, + "outputs": [], + "source": [ + "# AUC-PR curve\n", + "save_pr_curve(y_true, y_score, plot_dir.name, \"pr\")\n", + "# AUC-NT curve\n", + "save_nt_curve(y_true, y_score, plot_dir.name, \"nt\")\n", + "# ECE curve\n", + "ece_thresholds, ece_values = results_ece[-1]\n", + "ece_values /= max(ece_values)\n", + "save_custom_confidence_curve(\n", + " ece_thresholds, ece_values, plot_dir.name, \"ece\", \"Threshold\", \"|Accuracy − Confidence score|\"\n", + ")\n", + "# AUC-YC curve\n", + "yc_thresholds, yc_values = result_yc[-1]\n", + "save_custom_confidence_curve(\n", + " yc_thresholds, yc_values, plot_dir.name, \"yc\", \"Threshold\", \"True positive rate − False Positive Rate\"\n", + ")\n", + "\n", + "\n", + "display(\n", + " Image(filename=os.path.join(plot_dir.name, \"pr.png\"), retina=True),\n", + " Image(filename=os.path.join(plot_dir.name, \"nt.png\"), retina=True),\n", + " Image(filename=os.path.join(plot_dir.name, \"ece.png\"), retina=True),\n", + " Image(filename=os.path.join(plot_dir.name, \"yc.png\"), retina=True),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9f63a172", + "metadata": { + "id": "ad78630a" + }, + "source": [ + "You can use `scripts/speech_recognition/confidence/benchmark_asr_confidence.py` to find optimal confidence hyperparameters." + ] + }, + { + "cell_type": "markdown", + "id": "1d9a822d", + "metadata": { + "id": "15e25521" + }, + "source": [ + "# 4. Confidence applications" + ] + }, + { + "cell_type": "markdown", + "id": "8ab6e666", + "metadata": { + "id": "dbb82877" + }, + "source": [ + "## 4.1. Small WER improvenent\n", + "\n", + "Good confidence scores can slightly reduce WER by removing low confidence words from recognition results.\n", + "\n", + "Consider the following example." + ] + }, + { + "cell_type": "markdown", + "id": "4038863c", + "metadata": { + "id": "02eb4e1f" + }, + "source": [ + "Let's look at the detailed WER of the transcribed test set before and after removing words with low confidence score." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "204d36ac", + "metadata": { + "id": "fdf790b5" + }, + "outputs": [], + "source": [ + "drop_low_confidence_words = lambda x, y, z: \" \".join([xx for xx, yy in zip(x.split(), y) if yy >= z])\n", + "\n", + "\n", + "threshold = 0.001\n", + "\n", + "wer_initial = word_error_rate_detail([h.text for h in transcriptions], current_test_set.reference_texts)\n", + "print(\n", + " f\"\"\"WER detail before removing low confidence words:\n", + " WER:\\t{wer_initial[0]:.5f}\n", + " INS_rate:\\t{wer_initial[2]:.5f}\n", + " DEL_rate:\\t{wer_initial[3]:.5f}\n", + " SUB_rate:\\t{wer_initial[4]:.5f}\"\"\"\n", + ")\n", + "\n", + "wer_conf_dropped = word_error_rate_detail(\n", + " [drop_low_confidence_words(hyp.text, hyp.word_confidence, threshold) for hyp in transcriptions],\n", + " current_test_set.reference_texts,\n", + ")\n", + "print(\n", + " f\"\"\"WER detail after removing low confidence words:\n", + " WER:\\t{wer_conf_dropped[0]:.5f}\n", + " INS_rate:\\t{wer_conf_dropped[2]:.5f}\n", + " DEL_rate:\\t{wer_conf_dropped[3]:.5f}\n", + " SUB_rate:\\t{wer_conf_dropped[4]:.5f}\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4f153cdd", + "metadata": { + "id": "28ac85b1" + }, + "source": [ + "You can see that with the right (in this example, extremely low) `threshold` can reduce WER by a tiny bit, reducing insertions and substitutions yet increasing deletions.\n", + "\n", + "Now let's see how to find the optimal threshold.\n", + "\n", + "The most commonly used method for automatically determining the optimal cutoff threshold is taking the value which delivers the maximum of the unnormalized Youden's curve. This method allows you to remove the largest number of incorrect entities, sacrificing the minimum number of correct entities.\n", + "\n", + "However, the unnormalized $\\mathrm{MAX}_\\mathrm{YC}$ method does not work well for the purpose of the WER reduction. Let's compare this method to explicitly minimizing WER with respect to a threshold." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19147b4a", + "metadata": { + "id": "9b81e449" + }, + "outputs": [], + "source": [ + "from joblib import Parallel, delayed\n", + "from multiprocessing import cpu_count\n", + "from tqdm.notebook import tqdm\n", + "\n", + "def max_unnnormalized_yc(\n", + " y_true: Union[List[int], np.ndarray],\n", + " y_score: Union[List[float], np.ndarray],\n", + " n_bins: int = 100,\n", + " start: float = 0.0,\n", + " stop: float = 1.0,\n", + "):\n", + " \"\"\"Calculate the maximum of the unnormalized Youden's curve.\n", + " \"\"\"\n", + " y_true = np.array(y_true)\n", + " y_score = np.array(y_score)\n", + " thresholds = np.linspace(start, stop, n_bins + 1)\n", + " assert len(y_true) == len(y_score)\n", + " assert np.all(y_true >= 0) and np.all(y_true <= 1)\n", + " if np.all(y_true == 0) or np.all(y_true == 1):\n", + " return 0.0, 0.0\n", + " mask_correct = y_true == 1\n", + " y_score_correct = y_score[mask_correct]\n", + " y_score_incorrect = y_score[~mask_correct]\n", + " unnnormalized_yc = []\n", + " for threshold in thresholds:\n", + " tn = len((y_score_incorrect < threshold).nonzero()[0])\n", + " fn = len((y_score_correct < threshold).nonzero()[0])\n", + " unnnormalized_yc.append((threshold, tn - fn))\n", + " return max(unnnormalized_yc, key=lambda x: x[1])[0]\n", + "\n", + "\n", + "def min_wer(ref: List[str], transcriptions, n_bins: int = 100, start: float = 0.0, stop: float = 1.0):\n", + " \"\"\"Find the threshold value that delivers the minimum WER.\n", + " \"\"\"\n", + " thresholds = np.linspace(start, stop, n_bins + 1)\n", + " hyp = [(hyp.text, hyp.word_confidence) for hyp in transcriptions]\n", + " _get_wer = lambda x, y, z: (x, word_error_rate_detail([drop_low_confidence_words(yy[0], yy[1], x) for yy in y], z)[0])\n", + " wers = Parallel(n_jobs=cpu_count())(delayed(_get_wer)(threshold, hyp, ref) for threshold in tqdm(thresholds))\n", + " return min(wers, key=lambda x: x[1])\n", + "\n", + "\n", + "targets_with_confidence = [get_word_targets_with_confidence(tran) for tran in transcriptions]\n", + "correct_marks = [\n", + " get_correct_marks(r.split(), h.words) for r, h in zip(current_test_set.reference_texts, transcriptions)\n", + "]\n", + "y_true, y_score = np.array(\n", + " [[f, p[1]] for cm, twc in zip(correct_marks, targets_with_confidence) for f, p in zip(cm, twc)]\n", + ").T\n", + "\n", + "threshold_yc = max_unnnormalized_yc(y_true, y_score)\n", + "yc_wer_value = word_error_rate(\n", + " [drop_low_confidence_words(hyp.text, hyp.word_confidence, threshold_yc) for hyp in transcriptions],\n", + " current_test_set.reference_texts,\n", + ")\n", + "threshold_min_wer, min_wer_value = min_wer(current_test_set.reference_texts, transcriptions, stop=0.1)\n", + "\n", + "print(\n", + " f\"\"\" Initial WER: {wer_initial[0]:.5f}\n", + " Optimal threshold and WER based on the Youden's curve: {threshold_yc}, {yc_wer_value:.5f}\n", + " Optimal threshold for the minimum WER: {threshold_min_wer}, {min_wer_value:.5f}\n", + " \"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "425d010e", + "metadata": { + "id": "3b278d2d" + }, + "source": [ + "As you can see, the optimal cutoff threshold as the maximum of the Youden's curve makes WER significantly worse, and the optimal threshold for the minimum WER is near zero.\n", + "\n", + "Let's use a different confidence estimation setup to see if we can improve WER at least a bit further." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d080686", + "metadata": { + "id": "39f72c78" + }, + "outputs": [], + "source": [ + "confidence_cfg = ConfidenceConfig(\n", + " preserve_word_confidence=True,\n", + " preserve_token_confidence=True,\n", + " aggregation=\"min\",\n", + " measure_cfg=DictConfig({\"entropy_type\": \"tsallis\", \"alpha\": 1.5, \"entropy_norm\": \"lin\"}),\n", + ")\n", + "\n", + "model.change_decoding_strategy(\n", + " RNNTDecodingConfig(fused_batch_size=-1, strategy=\"greedy_batch\", confidence_cfg=confidence_cfg)\n", + " if is_rnnt\n", + " else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n", + ")\n", + "\n", + "transcriptions = model.transcribe(paths2audio_files=current_test_set.filepaths, batch_size=16, return_hypotheses=True, num_workers=4)\n", + "if is_rnnt:\n", + " transcriptions = transcriptions[0]\n", + "\n", + "threshold_min_wer, min_wer_value = min_wer(current_test_set.reference_texts, transcriptions)\n", + "\n", + "print(\n", + " f\"\"\" Initial WER: {wer_initial[0]:.5f}\n", + " Optimal threshold for the minimum WER: {threshold_min_wer}, {min_wer_value:.5f}\n", + " \"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e3c9cc02", + "metadata": { + "id": "e00581b1" + }, + "source": [ + "Overall, such an improvement in WER is too small to be considered. However, this opens up the possibility of improving WER through the use of more accurate confidence estimation methods." + ] + }, + { + "cell_type": "markdown", + "id": "694d1752", + "metadata": { + "id": "f9f89665" + }, + "source": [ + "## 4.2. Reducing hallucinations with confidence scores\n", + "\n", + "One common application of confidence scores is the removal of recognition hallucinations.\n", + "\n", + "Let's see how this can be done." + ] + }, + { + "cell_type": "markdown", + "id": "98a1ef83", + "metadata": { + "id": "c1c28379" + }, + "source": [ + "Firstly, let's obtain a dataset on which the ASR model can hallucinate.\n", + "\n", + "Here we make it from the librosa examples, reversing them and convolving with each other." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f12a5041", + "metadata": { + "id": "3b0a0b4c" + }, + "outputs": [], + "source": [ + "from itertools import combinations\n", + "import json\n", + "import librosa\n", + "import soundfile as sf\n", + "\n", + "def cyclic_sum(x, y):\n", + " if x.shape[0] < y.shape[0]:\n", + " x, y = y, x\n", + " if x.shape[0] > y.shape[0]:\n", + " y = np.take(y, range(0, x.shape[0]), mode='wrap')\n", + " return x + y\n", + "\n", + "def generate_noise_examples(example_list: List[str], save_dir: str, samplerate: int = 16000):\n", + " \"\"\"Generate noise examples with librosa.\n", + " It loads the selected example, inverts and perturbs them with each other.\n", + "\n", + " Returns:\n", + " A manifest with the noise wavs.\n", + " \"\"\"\n", + " samples = {ex: librosa.core.load(librosa.util.example(key=ex, hq=True), sr=samplerate)[0] \n", + " for ex in example_list}\n", + " noise_samples = {\"_\".join([left, right]): cyclic_sum(samples[left][::-1], samples[right][::-1]) \n", + " for left, right in combinations(samples.keys(), 2)}\n", + "\n", + " os.makedirs(save_dir, exist_ok=True)\n", + " manifest = os.path.join(save_dir, \"manifest.json\")\n", + " with open(manifest, \"tw\", encoding=\"utf-8\") as fout:\n", + " for k, v in noise_samples.items():\n", + " audio_path = os.path.join(save_dir, f\"{k}.wav\")\n", + " sf.write(audio_path, v, samplerate=samplerate)\n", + " metadata = {\n", + " \"audio_filepath\": audio_path,\n", + " \"duration\": librosa.core.get_duration(y=v, sr=samplerate),\n", + " \"label\": \"noise\",\n", + " \"text\": \"_\"\n", + " }\n", + " json.dump(metadata, fout)\n", + " fout.write('\\n')\n", + "\n", + " return manifest\n", + "\n", + "librosa_list_examples = ['brahms',\n", + " 'choice',\n", + " 'fishin',\n", + " 'humpback',\n", + " 'libri1',\n", + " 'libri2',\n", + " 'libri3',\n", + " 'nutcracker',\n", + " 'pistachio',\n", + " 'robin',\n", + " 'sweetwaltz',\n", + " 'trumpet',\n", + " 'vibeace']\n", + "sr = 16000\n", + "\n", + "noise_dir = os.path.join(DATA_DIR, \"noise\")\n", + "noise_manifest = generate_noise_examples(librosa_list_examples, noise_dir, sr)" + ] + }, + { + "cell_type": "markdown", + "id": "f28da61f", + "metadata": {}, + "source": [ + "The original examples contain speech, music, or noise. The resulring audio recordings are considered to contain no recognizable speech.\n", + "\n", + "You can listen to an example of the audios." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b4e7007", + "metadata": {}, + "outputs": [], + "source": [ + "noise_data = load_data(noise_manifest)\n", + "\n", + "display(Audio(noise_data.filepaths[0]))" + ] + }, + { + "cell_type": "markdown", + "id": "1db80ae4", + "metadata": { + "id": "f7f9ddca" + }, + "source": [ + "Now let's transcribe our new data, setting the default confidence estimator." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a872926", + "metadata": { + "id": "60f39094" + }, + "outputs": [], + "source": [ + "confidence_cfg = ConfidenceConfig(\n", + " preserve_word_confidence=True,\n", + " preserve_token_confidence=True,\n", + ")\n", + "\n", + "model.change_decoding_strategy(\n", + " RNNTDecodingConfig(fused_batch_size=-1, strategy=\"greedy_batch\", confidence_cfg=confidence_cfg)\n", + " if is_rnnt\n", + " else CTCDecodingConfig(confidence_cfg=confidence_cfg)\n", + ")\n", + "\n", + "noise_transcriptions = model.transcribe(\n", + " paths2audio_files=noise_data.filepaths, batch_size=4, return_hypotheses=True, num_workers=4\n", + ")\n", + "if is_rnnt:\n", + " noise_transcriptions = noise_transcriptions[0]" + ] + }, + { + "cell_type": "markdown", + "id": "3d097ca6", + "metadata": { + "id": "2f192186" + }, + "source": [ + "On a fully non-speech dataset, hallucinations can be measured as the Word Insertions per Second (WIS) value." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19c6321c", + "metadata": { + "id": "3589da00" + }, + "outputs": [], + "source": [ + "def word_insertions_per_second(texts: List[str], durations: List[float]):\n", + " \"\"\"Calculate the Word Insertions per Second (WIS) value for the given recognition results \n", + " and their corresponding audio duration.\n", + " \"\"\"\n", + " assert len(texts) == len(durations)\n", + "\n", + " wis_per_utt = [len(text.split(\" \")) / duration for text, duration in zip(texts, durations)]\n", + " return sum(wis_per_utt) / len(wis_per_utt), wis_per_utt\n", + "\n", + "wis, wis_per_utt = word_insertions_per_second([t.text for t in noise_transcriptions], noise_data.durations)\n", + "print(f\"Original Word Insertions per Second: {wis:.5f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "bcf44daf", + "metadata": { + "id": "a0d8135d" + }, + "source": [ + "Now, the ability of a confidence estimator to detect hallucinations is computed as the Hallucination Detection Rate (HDR).\n", + "\n", + "It shows how many of all hallucinations can be removed, provided that no more than some fixed percentage of correct words are erroneously removed (under normal recognition conditions).\n", + "\n", + "HDR is another name of the metric $\\mathrm{TNR}_{FNR=e}$ which is calculated as $\\mathrm{TNR}(Y,\\tau): \\mathrm{FNR}(X,\\tau) \\approx e$, where $X$ is the dataset with supervision (to tune $\\tau$) and $Y$ is the noise-only dataset. Typical $e$ value is 0.05.\n", + "\n", + "Let's compute HDR and the new WIS.\n", + "\n", + "The generated dataset is clearly distinct from speech, so $e=0.01$ is sufficient." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dac1f7d", + "metadata": { + "id": "0612ccf6" + }, + "outputs": [], + "source": [ + "def hdr(\n", + " y_true_speech: Union[List[int], np.ndarray],\n", + " y_score_speech: Union[List[float], np.ndarray],\n", + " y_score_noise: Union[List[float], np.ndarray],\n", + " max_fnr: float = 0.05,\n", + " n_bins: int = 100,\n", + ") -> Tuple[float, float]:\n", + " \"\"\"Compute Hallucination Detection Rate (HDR) from prediction scores.\n", + "\n", + " Returns:\n", + " tnr: True-Negateve Rate for HDR\n", + " threshold_hdr: Optomal threshold \n", + " \"\"\"\n", + " y_true_speech = np.array(y_true_speech)\n", + " y_score_speech = np.array(y_score_speech)\n", + " y_score_noise = np.array(y_score_noise)\n", + " thresholds = np.linspace(0, 1, n_bins + 1)\n", + " assert y_true_speech.shape[0] == y_score_speech.shape[0]\n", + " assert np.all(y_true_speech >= 0) and np.all(y_true_speech <= 1)\n", + " if np.all(y_true_speech == 0) or np.all(y_true_speech == 1):\n", + " return 0.0, 0.0\n", + " mask_correct = y_true_speech == 1\n", + " count_correct = max(mask_correct.nonzero()[0].shape[0], 1)\n", + " y_score_correct = y_score_speech[mask_correct]\n", + " threshold_hdr = 0.0\n", + " for threshold in thresholds:\n", + " fnr = (y_score_correct < threshold).nonzero()[0].shape[0] / count_correct\n", + " if fnr <= max_fnr:\n", + " threshold_hdr = threshold\n", + " else:\n", + " break\n", + " tnr = (y_score_noise < threshold_hdr).nonzero()[0].shape[0] / y_score_noise.shape[0]\n", + " return tnr, threshold_hdr\n", + "\n", + "\n", + "# e\n", + "max_fnr = 0.01\n", + "\n", + "correct_marks = [\n", + " mark for r, h in zip(current_test_set.reference_texts, transcriptions) for mark in get_correct_marks(r.split(), h.words)\n", + "]\n", + "y_score_speech = [w for h in transcriptions for w in h.word_confidence]\n", + "y_score_noise = [w for h in noise_transcriptions for w in h.word_confidence]\n", + "hdr_score, threshold_hdr = hdr(correct_marks, y_score_speech, y_score_noise, max_fnr=max_fnr)\n", + "wis_new = wis - wis * hdr_score\n", + "\n", + "hdr_score, wis_new\n", + "print(\n", + " f\"\"\" Hallucination Detection Rate for max_fnr={max_fnr} : {hdr_score:.5f}\n", + " New Word Insertions Per Second: {wis_new:.5f}\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "443938bc", + "metadata": { + "id": "418297d6" + }, + "source": [ + "Finally, let's print the noisy utterances to see if any more hallucinations persist." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dde9e7db", + "metadata": { + "id": "3815e8e3" + }, + "outputs": [], + "source": [ + "sorted_wis_indices = sorted(enumerate(wis_per_utt), key=lambda x: x[1])[::-1]\n", + "\n", + "show_dataset_with_confidence(\n", + " indices=sorted_wis_indices,\n", + " transcriptions=noise_transcriptions,\n", + " test_set=noise_data,\n", + " threshold=threshold_hdr,\n", + " filepaths=noise_data.filepaths,\n", + " html_show=is_colab,\n", + " min_dur_to_show=0.0,\n", + " utt_to_show=5,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "66f92938", + "metadata": { + "id": "0ac58ef2" + }, + "source": [ + "# Summary\n", + "This tutorial covered the basics of ASR confidence estimation and two examples of using ASR word confidence: WER reduction and hallusinations removal.\n", + "\n", + "You can follow this tutorial on [ASR Confidence-based Ensembles](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/Confidence_Ensembles.ipynb) to see another important application of ASR confidence estimation." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/asr/ASR_TTS_Tutorial.ipynb b/tutorials/asr/ASR_TTS_Tutorial.ipynb new file mode 100644 index 000000000000..007713ee3cc2 --- /dev/null +++ b/tutorials/asr/ASR_TTS_Tutorial.ipynb @@ -0,0 +1,846 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a3570803-9bfa-4e97-9891-5ae0759eb8ca", + "metadata": {}, + "source": [ + "# Hybrid ASR-TTS Models Tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "50fc294f-f319-4465-8f90-a28b49843e60", + "metadata": {}, + "source": [ + "This tutorial is intended to introduce you to using ASR-TTS Hybrid Models, also known as `ASRWithTTSModel`, to finetune existing ASR models using an integrated text-to-mel-spectrogram generator. " + ] + }, + { + "cell_type": "markdown", + "id": "d2a01ca5-bd48-4d82-a97d-5b07a7b27ca0", + "metadata": {}, + "source": [ + "## ASR-TTS Models: Description" + ] + }, + { + "cell_type": "markdown", + "id": "b32467a9-c458-4590-aff7-e8d1e91b0870", + "metadata": {}, + "source": [ + "### Problem\n", + "\n", + "Adapting ASR models to a new text domain is a challenging task. Modern end-to-end systems can require several hundreds and thousands of hours to perform recognition with high accuracy. Acquiring audio-text paired data for a specific domain can be prohibitively expensive. Text-only data, on the other side, is widely available. \n", + "\n", + "One of the approaches for efficient adaptation is synthesizing audio data from text and using such data for training the ASR model conventionally. We modify this approach, incorporating TTS and ASR systems into a single model. 
We use only a lightweight multi-speaker text-to-mel-spectrogram generator (without vocoder) with an optional enhancer that mitigates the mismatch between natural and synthetic spectrograms.\n", + "\n", + "### Architecture\n", + "\n", + "*(Figure: ASR-TTS hybrid model architecture)*\n", + "\n", + "`ASRWithTTSModel` is a transparent wrapper for three models:\n", + "- ASR model (`EncDecCTCModelBPE`, `EncDecRNNTBPEModel` or `EncDecHybridRNNTCTCBPEModel` are supported)\n", + "- frozen text-to-mel-spectrogram model (currently, only `FastPitch` model is supported)\n", + "- optional frozen enhancer model\n", + "\n", + "The architecture is shown in the figure. \n", + "\n", + "The model can take text or audio as input during training. In the case of audio input, a mel spectrogram is extracted as usual and passed to the ASR neural network. In the case of textual input, the mel spectrogram generator produces a spectrogram on the fly from the text. The spectrogram is improved by the enhancer (if present) and fed into the ASR model. \n", + "\n", + "### Capabilities and Limitations\n", + "\n", + "This approach can be used to finetune a pretrained ASR model using text-only data. Training new models from scratch is also possible. The text should contain phrases and sentences and be split into sentences (~45 words maximum, corresponding to ~16.7 seconds of synthesized audio). Using only isolated words is not recommended, since this does not allow the ASR model to adapt to recognizing new words in context. \n", + "\n", + "Mixing text-only data with audio-text pairs from the original domain is recommended to preserve performance on the original data. \n", + "Also, fusing BatchNorm (see parameters below) is recommended for the best performance when using a large proportion of text compared to the amount of audio-text pairs in the finetuning process.\n", + "\n", + "\n", + "### Implementation Details and Experiments\n", + "\n", + "Further details about implementation and experiments can be found in the paper [Text-only domain adaptation for end-to-end ASR using integrated text-to-mel-spectrogram generator](https://arxiv.org/abs/2302.14036).\n" + ] + }, + { + "cell_type": "markdown", + "id": "2702d081-c675-4a96-8263-6059e310d048", + "metadata": {}, + "source": [ + "## Example: Finetuning ASR Model Using Text-Only Data" + ] + }, + { + "cell_type": "markdown", + "id": "30fe41a3-f36c-4803-a7f0-4260fb111478", + "metadata": {}, + "source": [ + "In this example, we will finetune a pretrained small Conformer-CTC model using text-only data from the AN4 dataset. The [AN4 dataset](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/datasets.html#an4-dataset) is a small dataset consisting of people spelling out addresses, names, and other entities.\n", + "\n", + "The model is pretrained on LibriSpeech data and performs poorly on AN4 data (`~17.7%` WER on test data).\n", + "We will use only text from the train part to construct text-only training data for our model and will achieve good performance on the test part of the AN4 dataset (`~2%` WER)."
Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run the following cell to set up dependencies.\n", + "\n", + "NOTE: The user is responsible for checking the content of datasets and the applicable licenses and determining if they are suitable for the intended use." + ] + }, + { + "cell_type": "markdown", + "id": "4685a9da-b3f8-4b95-ba74-64a114223233", + "metadata": {}, + "source": [ + "### Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d22d241-6c46-492c-99db-3bd69777243c", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + "except (ImportError, ModuleNotFoundError):\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc38a961-8822-4685-89ae-ab6f591f9c28", + "metadata": {}, + "outputs": [], + "source": [ + "BRANCH = 'main'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd60b1c4-7b1d-421d-9d63-95d7458bbcbd", + "metadata": {}, + "outputs": [], + "source": [ + "# If you're using Google Colab and not running locally, run this cell.\n", + "\n", + "if IN_COLAB:\n", + " ## Install dependencies\n", + " !pip install wget\n", + " !apt-get install sox libsndfile1 ffmpeg\n", + " !pip install text-unidecode\n", + "\n", + " ## Install NeMo\n", + " !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]" + ] + }, + { + "cell_type": "markdown", + "id": "08f99618-6f83-44b3-bc8e-f7df04fc471c", + "metadata": {}, + "source": [ + "### Import necessary libraries and utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74f780b1-9b72-4acf-bcf0-64e1ce84e76d", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "import string\n", + "import tempfile\n", + "\n", + "from omegaconf import OmegaConf\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "from tqdm.auto import tqdm\n", + "import wget\n", + "\n", + "from nemo.collections.asr.models import EncDecCTCModelBPE\n", + "from nemo.collections.asr.models.hybrid_asr_tts_models import ASRWithTTSModel\n", + "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest\n", + "from nemo.collections.tts.models import FastPitchModel, SpectrogramEnhancerModel\n", + "from nemo.utils.notebook_utils import download_an4\n", + "\n", + "from nemo_text_processing.text_normalization.normalize import Normalizer" + ] + }, + { + "cell_type": "markdown", + "id": "ca928d36-fb0d-439b-bac0-299e98a72d02", + "metadata": {}, + "source": [ + "### Prepare Data" + ] + }, + { + "cell_type": "markdown", + "id": "702e8e92-17b2-4f34-a2d9-c72b94501bf5", + "metadata": {}, + "source": [ + "Download and preprocess AN4 data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62c7cfec-aa98-4fc5-8b31-23ee1d59f311", + "metadata": {}, + "outputs": [], + "source": [ + "DATASETS_DIR = Path(\"./datasets\") # directory for data\n", + "CHECKPOINTS_DIR = Path(\"./checkpoints/\") # directory for checkpoints" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "659db73e-dcd7-455c-8140-20e104d6ac00", + "metadata": {}, + "outputs": [], + "source": [ + "# create directories if necessary\n", + "DATASETS_DIR.mkdir(parents=True, exist_ok=True)\n", + "CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36830e7f-5293-4401-8c56-780127b47385", + "metadata": {}, + "outputs": [], + "source": [ + "download_an4(data_dir=f\"{DATASETS_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e77f5062-9acb-4f39-b811-a5b11dd6f76f", + "metadata": {}, + "outputs": [], + "source": [ + "AN4_DATASET = DATASETS_DIR / \"an4\"" + ] + }, + { + "cell_type": "markdown", + "id": "403b63b0-8aab-43aa-a455-31f588d1772f", + "metadata": {}, + "source": [ + "### Construct text-only training data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35654ee1-3869-4289-bd52-15818c0ccf69", + "metadata": {}, + "outputs": [], + "source": [ + "# read original training data\n", + "an4_train_data = read_manifest(AN4_DATASET / \"train_manifest.json\")" + ] + }, + { + "cell_type": "markdown", + "id": "a17f583c-2a5c-4faf-84bd-eb04c2921e01", + "metadata": {}, + "source": [ + "Text-only manifest should contain three fields:\n", + "- `text`: target text for the ASR model\n", + "- `tts_text`: text to use as a source for the TTS model (unnormalized)\n", + "- `tts_text_normalized`: text to use as a source for TTS model (normalized)\n", + "\n", + "If `tts_text_normalized` is not present, `tts_text` will be used, and normalization will be done when loading the dataset.\n", + "It is highly recommended to normalize the text and manually create the `tts_text_normalized` field since current normalizers are unsuitable for processing a large amount of text on the fly." 
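To make the target format concrete, a single finished text-only record could look like the following (a made-up example; AN4 text itself contains nothing that actually requires normalization):

```python
# Hypothetical text-only manifest entry with all three fields filled in
example_record = {
    "text": "two oh one",                 # target transcription for the ASR loss
    "tts_text": "201",                    # raw (unnormalized) source text for the TTS model
    "tts_text_normalized": "two oh one",  # pre-normalized TTS text (recommended to precompute)
}
```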
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5938a8c2-e239-4a45-a716-dc11a981aec7", + "metadata": {}, + "outputs": [], + "source": [ + "# fill `text` and `tts_text` fields with the source data\n", + "textonly_data = []\n", + "for record in an4_train_data:\n", + " text = record[\"text\"]\n", + " textonly_data.append({\"text\": text, \"tts_text\": text})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f6a5735-a5c2-4a8b-8116-bfc535a2c299", + "metadata": {}, + "outputs": [], + "source": [ + "WHITELIST_URL = (\n", + " \"https://raw.githubusercontent.com/NVIDIA/NeMo-text-processing/main/\"\n", + " \"nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv\"\n", + ")\n", + "\n", + "\n", + "def get_normalizer() -> Normalizer:\n", + " with tempfile.TemporaryDirectory() as data_dir:\n", + " whitelist_path = Path(data_dir) / \"lj_speech.tsv\"\n", + " if not whitelist_path.exists():\n", + " wget.download(WHITELIST_URL, out=str(data_dir))\n", + "\n", + " normalizer = Normalizer(\n", + " lang=\"en\",\n", + " input_case=\"cased\",\n", + " whitelist=str(whitelist_path),\n", + " overwrite_cache=True,\n", + " cache_dir=None,\n", + " )\n", + " return normalizer" + ] + }, + { + "cell_type": "markdown", + "id": "dd0253aa-d7f1-47ee-a142-099b71241270", + "metadata": {}, + "source": [ + "Construct the `tts_text_normalized` field by applying an English normalizer to the text.\n", + "\n", + "AN4 data doesn't contain numbers, currency, or other such entities, so the normalizer is used here only for demonstration purposes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27bb29d5-d44d-4026-98f8-5f0b1241b39a", + "metadata": {}, + "outputs": [], + "source": [ + "normalizer = get_normalizer()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9400e6d3-ba92-442a-8dd4-117e95dce2ea", + "metadata": {}, + "outputs": [], + "source": [ + "for record in tqdm(textonly_data):\n", + " record[\"tts_text_normalized\"] = normalizer.normalize(\n", + " record[\"tts_text\"], verbose=False, punct_pre_process=True, punct_post_process=True\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "30a934b0-9b58-4bad-bb9a-ab78d81c3859", + "metadata": {}, + "source": [ + "Save the manifest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1833ac15-1750-4468-88bc-2343fbabe4d8", + "metadata": {}, + "outputs": [], + "source": [ + "write_manifest(AN4_DATASET / \"train_text_manifest.json\", textonly_data)" + ] + }, + { + "cell_type": "markdown", + "id": "fa3a2371-8c78-4dd1-9605-a668adf52b4a", + "metadata": {}, + "source": [ + "### Save pretrained checkpoints" + ] + }, + { + "cell_type": "markdown", + "id": "7eb14117-8b8b-4170-ab8c-ce496522a361", + "metadata": {}, + "source": [ + "First, we will load pretrained models from NGC and save them as `.nemo` checkpoints. \n", + "Our hybrid model will be constructed from these checkpoints.\n", + "We will use:\n", + "- a small Conformer-CTC ASR model trained on LibriSpeech data (for finetuning)\n", + "- a multi-speaker TTS FastPitch model trained on LibriTTS data. 
Spectrogram parameters for this model are the same as those used in the ASR model\n", + "- enhancer, which is trained adversarially on the output of the TTS model and natural spectrograms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43c5c75a-b6e0-4b3c-ad26-a07b483d84e6", + "metadata": {}, + "outputs": [], + "source": [ + "ASR_MODEL_PATH = CHECKPOINTS_DIR / \"stt_en_conformer_ctc_small_ls.nemo\"\n", + "TTS_MODEL_PATH = CHECKPOINTS_DIR / \"fastpitch.nemo\"\n", + "ENHANCER_MODEL_PATH = CHECKPOINTS_DIR / \"enhancer.nemo\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40976e22-7a7b-42b2-86a1-9eaaef4c1c22", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# asr model: stt_en_conformer_ctc_small_ls\n", + "asr_model = EncDecCTCModelBPE.from_pretrained(model_name=\"stt_en_conformer_ctc_small_ls\")\n", + "asr_model.save_to(f\"{ASR_MODEL_PATH}\")\n", + "\n", + "# tts model: tts_en_fastpitch_for_asr_finetuning\n", + "tts_model = FastPitchModel.from_pretrained(model_name=\"tts_en_fastpitch_for_asr_finetuning\")\n", + "tts_model.save_to(f\"{TTS_MODEL_PATH}\")\n", + "\n", + "# enhancer model: tts_en_spectrogram_enhancer_for_asr_finetuning\n", + "enhancer_model = SpectrogramEnhancerModel.from_pretrained(model_name=\"tts_en_spectrogram_enhancer_for_asr_finetuning\")\n", + "enhancer_model.save_to(f\"{ENHANCER_MODEL_PATH}\")" + ] + }, + { + "cell_type": "markdown", + "id": "32d1e242-0ab0-43bf-aaa0-997d284c2c1b", + "metadata": {}, + "source": [ + "### Construct hybrid ASR-TTS model " + ] + }, + { + "cell_type": "markdown", + "id": "2210eb07-6d44-44e0-a0ad-866f1e89873a", + "metadata": {}, + "source": [ + "#### Config Parameters\n", + "\n", + "`Hybrid ASR-TTS model` consists of three parts:\n", + "\n", + "* ASR model (``EncDecCTCModelBPE``, ``EncDecRNNTBPEModel`` or ``EncDecHybridRNNTCTCBPEModel``)\n", + "* TTS Mel Spectrogram Generator (currently, only `FastPitch` model is supported)\n", + "* Enhancer model (optional)\n", + "\n", + "Also, the config allows to specify a text-only dataset.\n", + "\n", + "Main parts of the config:\n", + "\n", + "* ASR model\n", + " * ``asr_model_path``: path to the ASR model checkpoint (`.nemo`) file, loaded only once, then the config of the ASR model is stored in the ``asr_model`` field\n", + " * ``asr_model_type``: needed only when training from scratch. ``rnnt_bpe`` corresponds to ``EncDecRNNTBPEModel``, ``ctc_bpe`` to ``EncDecCTCModelBPE``, ``hybrid_rnnt_ctc_bpe`` to ``EncDecHybridRNNTCTCBPEModel``\n", + " * ``asr_model_fuse_bn``: fusing BatchNorm in the pretrained ASR model, can improve quality in finetuning scenario\n", + "* TTS model\n", + " * ``tts_model_path``: path to the pretrained TTS model checkpoint (`.nemo`) file, loaded only once, then the config of the model is stored in the ``tts_model`` field\n", + "* Enhancer model\n", + " * ``enhancer_model_path``: optional path to the enhancer model. Loaded only once, the config is stored in the ``enhancer_model`` field\n", + "* ``train_ds``\n", + " * ``text_data``: properties related to text-only data\n", + " * ``manifest_filepath``: path (or paths) to text-only dataset manifests\n", + " * ``speakers_filepath``: path (or paths) to the text file containing speaker ids for the multi-speaker TTS model (speakers are sampled randomly during training)\n", + " * ``min_words`` and ``max_words``: parameters to filter text-only manifests by the number of words\n", + " * ``tokenizer_workers``: number of workers for initial tokenization (when loading the data). 
``num_CPUs / num_GPUs`` is a recommended value.\n", + " * ``asr_tts_sampling_technique``, ``asr_tts_sampling_temperature``, ``asr_tts_sampling_probabilities``: sampling parameters for text-only and audio-text data (if both specified). Correspond to ``sampling_technique``, ``sampling_temperature``, and ``sampling_probabilities`` parameters of the `nemo.collections.common.data.dataset.ConcatDataset`.\n", + " * all other components are similar to conventional ASR models\n", + "* ``validation_ds`` and ``test_ds`` correspond to the underlying ASR model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d6dd499-d388-4ee3-9a01-d739b16e6ad7", + "metadata": {}, + "outputs": [], + "source": [ + "# load config\n", + "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6701dc8-cb3b-44cc-aab5-fb6e2c1dadb5", + "metadata": {}, + "outputs": [], + "source": [ + "config = OmegaConf.load(\"./configs/hybrid_asr_tts.yaml\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c13b3c96-4074-415f-95d2-17569886bfcd", + "metadata": {}, + "outputs": [], + "source": [ + "NUM_EPOCHS = 10" + ] + }, + { + "cell_type": "markdown", + "id": "4d090c5d-44a7-401a-a753-b8779b1c1e0b", + "metadata": {}, + "source": [ + "We will use all available speakers (sampled uniformly)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c41e5e8-d926-4b83-8725-bae5a82121cf", + "metadata": {}, + "outputs": [], + "source": [ + "TTS_SPEAKERS_PATH = Path(\"./checkpoints/speakers.txt\")\n", + "\n", + "with open(TTS_SPEAKERS_PATH, \"w\", encoding=\"utf-8\") as f:\n", + " for speaker_id in range(tts_model.cfg.n_speakers):\n", + " print(speaker_id, file=f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c07c07c-cb15-4a1c-80bf-20eaffaa65d9", + "metadata": {}, + "outputs": [], + "source": [ + "config.model.asr_model_path = ASR_MODEL_PATH\n", + "config.model.tts_model_path = TTS_MODEL_PATH\n", + "config.model.enhancer_model_path = ENHANCER_MODEL_PATH\n", + "\n", + "# fuse BatchNorm automatically in Conformer for better performance\n", + "config.model.asr_model_fuse_bn = True\n", + "\n", + "# training data\n", + "# constructed dataset\n", + "config.model.train_ds.text_data.manifest_filepath = str(AN4_DATASET / \"train_text_manifest.json\")\n", + "# speakers for TTS model\n", + "config.model.train_ds.text_data.speakers_filepath = f\"{TTS_SPEAKERS_PATH}\"\n", + "config.model.train_ds.manifest_filepath = None # audio-text pairs - we don't use them here\n", + "config.model.train_ds.batch_size = 8\n", + "\n", + "# validation data\n", + "config.model.validation_ds.manifest_filepath = str(AN4_DATASET / \"test_manifest.json\")\n", + "config.model.validation_ds.batch_size = 8\n", + "\n", + "config.trainer.max_epochs = NUM_EPOCHS\n", + "\n", + "config.trainer.devices = 1\n", + "config.trainer.strategy = None # use 1 device, no need for ddp strategy\n", + "\n", + "OmegaConf.resolve(config)" + ] + }, + { + "cell_type": "markdown", + "id": "8ae6cb2e-f571-4b53-8897-bb8ba0fc1146", + "metadata": {}, + "source": [ + "#### Construct trainer and ASRWithTTSModel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac4ae885-dec4-4ce9-8f69-a1f35d04b08c", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = pl.Trainer(**config.trainer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"8f815762-b08d-4d3c-8fd3-61afa511eab4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "hybrid_model = ASRWithTTSModel(config.model)" + ] + }, + { + "cell_type": "markdown", + "id": "ca2c1bf2-28a9-4902-9c73-d96e04b21a46", + "metadata": {}, + "source": [ + "#### Validate the model\n", + "\n", + "Expect `~17.7%` WER on the AN4 test data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffa5f5c6-0609-4f46-aa0c-747319035417", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.validate(hybrid_model)" + ] + }, + { + "cell_type": "markdown", + "id": "701ee9c7-91a1-4917-bf7d-ab26b625c7bf", + "metadata": {}, + "source": [ + "#### Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f79761c9-b882-4f14-911f-4a960ff81554", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "trainer.fit(hybrid_model)" + ] + }, + { + "cell_type": "markdown", + "id": "eac18c7c-bdcb-40ad-9c50-37f89fb4aa2a", + "metadata": {}, + "source": [ + "#### Validate the model after training\n", + "\n", + "Expect `~2%` WER on the AN4 test data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd927e87-13fb-4b61-8b4a-a6850780f605", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "trainer.validate(hybrid_model)" + ] + }, + { + "cell_type": "markdown", + "id": "6d25a77d-35ed-44b5-9ef5-318afa321acf", + "metadata": {}, + "source": [ + "### Save final model. Extract pure ASR model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f53ebd3-b89a-47e4-a0a5-ed3a3572f7c1", + "metadata": {}, + "outputs": [], + "source": [ + "# save full model: the model can be further used for finetuning\n", + "hybrid_model.save_to(\"checkpoints/finetuned_hybrid_model.nemo\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0560c2c-af28-4d8f-b36d-c18ec6a482a8", + "metadata": {}, + "outputs": [], + "source": [ + "# extract the resulting ASR model from the hybrid model\n", + "hybrid_model.save_asr_model_to(\"checkpoints/finetuned_asr_model.nemo\")" + ] + }, + { + "cell_type": "markdown", + "id": "2de58fbb-50be-42cd-9095-01cacfdb6931", + "metadata": {}, + "source": [ + "## Using Scripts (examples)" + ] + }, + { + "cell_type": "markdown", + "id": "86655198-b1fc-4615-958c-7c01f3cbd024", + "metadata": {}, + "source": [ + "`/examples/asr/asr_with_tts/` contains scripts for finetuning existing models and training new models from scratch." + ] + }, + { + "cell_type": "markdown", + "id": "b5837536-8280-475c-a581-caaee00edfca", + "metadata": {}, + "source": [ + "### Finetuning Existing Model" + ] + }, + { + "cell_type": "markdown", + "id": "84df9aeb-3b5e-41fc-a8d0-dfc660e71375", + "metadata": {}, + "source": [ + "To finetune existing ASR model using text-only data use `/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py` script with the corresponding config `/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml`.\n", + "\n", + "Please specify paths to all the required models (ASR, TTS, and Enhancer checkpoints), along with `train_ds.text_data.manifest_filepath` and `train_ds.text_data.speakers_filepath`." 
+ ] + }, + { + "cell_type": "markdown", + "id": "78b9028c-02ce-4af4-b510-a431f4a2f62b", + "metadata": {}, + "source": [ + "```shell\n", + "python speech_to_text_bpe_with_text_finetune.py \\\n", + " model.asr_model_path= \\\n", + " model.tts_model_path= \\\n", + " model.enhancer_model_path= \\\n", + " model.asr_model_fuse_bn= \\\n", + " model.train_ds.manifest_filepath= \\\n", + " model.train_ds.text_data.manifest_filepath= \\\n", + " model.train_ds.text_data.speakers_filepath= \\\n", + " model.train_ds.text_data.tokenizer_workers=4 \\\n", + " model.validation_ds.manifest_filepath= \\\n", + " model.train_ds.batch_size=\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "0b17c097-a3b1-49a3-8f54-f07b94218d0b", + "metadata": {}, + "source": [ + "### Training a New Model from Scratch" + ] + }, + { + "cell_type": "markdown", + "id": "6d75b928-57b3-4180-bd09-37e018eef7ef", + "metadata": {}, + "source": [ + "```shell\n", + "python speech_to_text_bpe_with_text.py \\\n", + " # (Optional: --config-path= --config-name=) \\\n", + " ++asr_model_type= \\\n", + " ++tts_model_path= \\\n", + " ++enhancer_model_path= \\\n", + " model.tokenizer.dir= \\\n", + " model.tokenizer.type=\"bpe\" \\\n", + " model.train_ds.manifest_filepath= \\\n", + " ++model.train_ds.text_data.manifest_filepath= \\\n", + " ++model.train_ds.text_data.speakers_filepath= \\\n", + " ++model.train_ds.text_data.min_words=1 \\\n", + " ++model.train_ds.text_data.max_words=45 \\\n", + " ++model.train_ds.text_data.tokenizer_workers=4 \\\n", + " model.validation_ds.manifest_filepath= \\\n", + " model.train_ds.batch_size= \\\n", + " trainer.max_epochs= \\\n", + " trainer.num_nodes= \\\n", + " trainer.accumulate_grad_batches= \\\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "01c17712-ae8d-49cb-ade1-ded168676e27", + "metadata": {}, + "source": [ + "## Training TTS Models for ASR Finetuning" + ] + }, + { + "cell_type": "markdown", + "id": "422dc3b2-d29f-4ed0-b4d2-6d32b35dfb7b", + "metadata": {}, + "source": [ + "### TTS Model (FastPitch)\n", + "\n", + "TTS model for the purpose of ASR model finetuning should be trained with the same mel spectrogram parameters as used in the ASR model. The typical parameters are `10ms` hop length, `25ms` window length, and the highest band of 8kHz (for 16kHz data). Other parameters are the same as for common multi-speaker TTS models.\n", + "\n", + "Mainly we observed two differences specific to TTS models for ASR:\n", + "- adding more speakers and more data improves the final ASR model quality (but not the perceptual quality of the TTS model)\n", + "- training for more epochs can also improve the quality of the ASR system (but MSE loss used for the TTS model can be higher than optimal on validation data)\n", + "\n", + "Use script `/examples/tts/fastpitch.py` to train a FastPitch model.\n", + "More details about the FastPitch model can be found in the [documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/tts/models.html#fastpitch). \n", + "\n", + "### Enhancer\n", + "Use script `/examples/tts/spectrogram_enhancer.py` to train an Enhancer model. 
More details can be found in the \n", + "[documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/tts/models.html).\n", + "\n", + "### Models Used in This Tutorial\n", + "\n", + "Some details about the models used in this tutorial can be found on [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning).\n", + "\n", + "The system is also described in detail in the paper [Text-only domain adaptation for end-to-end ASR using integrated text-to-mel-spectrogram generator](https://arxiv.org/abs/2302.14036)." + ] + }, + { + "cell_type": "markdown", + "id": "9a9a6cd3-4bdc-4b6e-b4b1-3bfd50fd01b3", + "metadata": {}, + "source": [ + "## Summary" + ] + }, + { + "cell_type": "markdown", + "id": "e2890c61-e4b7-47aa-a086-bc483ae7141f", + "metadata": {}, + "source": [ + "This tutorial demonstrated the main concepts behind hybrid ASR-TTS models, which can be used to finetune ASR models and to train new ones from scratch. \n", + "The ability to achieve good text-only adaptation results was demonstrated by finetuning a small Conformer model on text-only data from the AN4 dataset." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ml38", + "language": "python", + "name": "ml38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/asr/Confidence_Ensembles.ipynb b/tutorials/asr/Confidence_Ensembles.ipynb index f9617c75e36a..4516d2b70d6d 100644 --- a/tutorials/asr/Confidence_Ensembles.ipynb +++ b/tutorials/asr/Confidence_Ensembles.ipynb @@ -110,7 +110,7 @@ "\n", "### How to estimate a model's confidence?\n", "\n", - "Good news, we have a whole separate [tutorial](TBD) on this topic! You can go through it if you want to know all the details about different ways to estimate confidence of NeMo ASR models. There are different confidence measures and aggregation functions and for the absolute best performance, you will need to run a grid-search to pick the best confidence estimation way for your specific models and data.\n", + "Good news, we have a whole separate [tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/ASR_Confidence_Estimation.ipynb) on this topic! You can go through it if you want to know all the details about different ways to estimate confidence of NeMo ASR models. There are different confidence measures and aggregation functions and for the absolute best performance, you will need to run a grid-search to pick the best confidence estimation way for your specific models and data.\n", "\n", "That being said, we found that there exist a set of confidence parameters that work pretty well on a large set of models and datsets. They are default in NeMo and so you might not need to worry about running the search. 
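For intuition, here is a toy, framework-agnostic example of one possible confidence measure: utterance confidence as one minus the mean normalized entropy of per-frame posteriors. This is only an illustration of the general idea; NeMo's confidence estimation utilities provide several measures and aggregation functions, as described in the linked tutorial.

```python
import numpy as np

def utterance_confidence(frame_probs: np.ndarray) -> float:
    """Toy confidence: 1 - mean normalized entropy of per-frame posteriors.

    frame_probs: array of shape [num_frames, vocab_size], rows sum to 1.
    Returns a value in [0, 1]; higher means the model was more certain.
    """
    eps = 1e-10
    entropy = -(frame_probs * np.log(frame_probs + eps)).sum(axis=-1)
    entropy /= np.log(frame_probs.shape[-1])  # normalize entropy to [0, 1]
    return float(1.0 - entropy.mean())
```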
If you do want to maximize the performance by tuning the confidence parameters, you only need to add [a few extra config lines](#Building-and-evaluating-ensemble-(tuned-parameters)).\n", "\n", diff --git a/tutorials/asr/Offline_ASR_with_VAD_for_CTC_models.ipynb b/tutorials/asr/Offline_ASR_with_VAD_for_CTC_models.ipynb index 7e9d0378bc1f..8a8335ac1542 100644 --- a/tutorials/asr/Offline_ASR_with_VAD_for_CTC_models.ipynb +++ b/tutorials/asr/Offline_ASR_with_VAD_for_CTC_models.ipynb @@ -389,7 +389,7 @@ "source": [ "# Further Reading\n", "\n", - "There are two ways to incorporate VAD into ASR pipeline. The first strategy is to drop the frames that are predicted as `non-speech` by VAD, as already discussed in this tutorial. The second strategy is to keep all the frames and mask the `non-speech` frames with zero-signal values. Also, instead of using segment-VAD as shown in this tutorial, we can use frame-VAD model for faster inference and better accuracy. For more information, please refer to the script [speech_to_text_with_vad.py](https://github.com/NVIDIA/NeMo/blob/stable/examples/asr_vad/speech_to_text_with_vad.py)." + "There are two ways to incorporate VAD into ASR pipeline. The first strategy is to drop the frames that are predicted as `non-speech` by VAD, as already discussed in this tutorial. The second strategy is to keep all the frames and mask the `non-speech` frames with zero-signal values. Also, instead of using segment-VAD as shown in this tutorial, we can use frame-VAD model for faster inference and better accuracy. For more information, please refer to the script [speech_to_text_with_vad.py](https://github.com/NVIDIA/NeMo/blob/stable/examples/asr/asr_vad/speech_to_text_with_vad.py)." ] } ], diff --git a/tutorials/nlp/SpellMapper_English_ASR_Customization.ipynb b/tutorials/nlp/SpellMapper_English_ASR_Customization.ipynb index cc949ad699b3..1be4704cc13c 100644 --- a/tutorials/nlp/SpellMapper_English_ASR_Customization.ipynb +++ b/tutorials/nlp/SpellMapper_English_ASR_Customization.ipynb @@ -85,7 +85,7 @@ "# Install NeMo library. If you are running locally (rather than on Google Colab), comment out the below lines\n", "# and instead follow the instructions at https://github.com/NVIDIA/NeMo#Installation\n", "GITHUB_ACCOUNT = \"NVIDIA\"\n", - "BRANCH = \"main\"\n", + "BRANCH = 'main'\n", "!python -m pip install git+https://github.com/{GITHUB_ACCOUNT}/NeMo.git@{BRANCH}#egg=nemo_toolkit[all]\n", "\n", "# Download local version of NeMo scripts. If you are running locally and want to use your own local NeMo code,\n", @@ -536,7 +536,7 @@ "id": "b1K6paeee2Iu" }, "source": [ - "As we mentioned earlier, this model pipeline is intended to work with custom vocabularies up to several thousand entries. Since the whole medical vocabulary contains 110k entries, we restrict our custom vocabulary to 5000+ terms that occured in given corpus of abstracts.\n", + "As we mentioned earlier, this model pipeline is intended to work with custom vocabularies up to several thousand entries. Since the whole medical vocabulary contains 110k entries, we restrict our custom vocabulary to 5000+ terms that occurred in given corpus of abstracts.\n", "\n", "The goal of indexing our custom vocabulary is to build an index where key is a letter n-gram and value is the whole phrase. The keys are n-grams in the given user phrase and their misspelled variants taken from our collection of n-\n", "gram mappings (see Index of custom vocabulary in Fig. 
1)\n", @@ -1273,7 +1273,7 @@ "### Filtering by Dynamic Programming(DP) score\n", "\n", "What else can be done?\n", - "Given a fragment and its potential replacement, we can apply **dynamic programming** to find the most probable \"translation\" path between them. We will use the same n-gram mapping vocabulary, because its frequencies give us \"translation probability\" of each n-gram pair. The final path score can be calculated as maximum sum of log probalities of matching n-grams along this path.\n", + "Given a fragment and its potential replacement, we can apply **dynamic programming** to find the most probable \"translation\" path between them. We will use the same n-gram mapping vocabulary, because its frequencies give us \"translation probability\" of each n-gram pair. The final path score can be calculated as maximum sum of log probabilities of matching n-grams along this path.\n", "Let's look at an example. " ] },