Adding docs and models for multiple lookahead cache-aware ASR (#7067)
github-actions[bot] authored and hsiehjackson committed Aug 10, 2023
1 parent 39d477e commit 6d08704
Showing 10 changed files with 82 additions and 19 deletions.
4 changes: 2 additions & 2 deletions README.rst
@@ -86,9 +86,9 @@ Key Features
* Hybrid Transducer/CTC
* NeMo Original `Multi-blank Transducers <https://arxiv.org/abs/2211.03541>`_ and `Token-and-Duration Transducers (TDT) <https://arxiv.org/abs/2304.06795>`_
* Streaming/Buffered ASR (CTC/Transducer) - `Chunked Inference Examples <https://github.com/NVIDIA/NeMo/tree/stable/examples/asr/asr_chunked_inference>`_
* Cache-aware Streaming Conformer - `<https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer>`_
* Cache-aware Streaming Conformer with multiple lookaheads - `<https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer>`_
* Beam Search decoding
* `Language Modelling for ASR <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/asr_language_modeling.html>`_: N-gram LM in fusion with Beam Search decoding, Neural Rescoring with Transformer
* `Language Modelling for ASR (CTC and RNNT) <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/asr_language_modeling.html>`_: N-gram LM in fusion with Beam Search decoding, Neural Rescoring with Transformer
* `Support of long audios for Conformer with memory efficient local attention <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/results.html#inference-on-long-audio>`_
* `Speech Classification, Speech Command Recognition and Language Identification <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speech_classification/intro.html>`_: MatchboxNet (Command Recognition), AmberNet (LangID)
* `Voice activity Detection (VAD) <https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/speech_classification/models.html#marblenet-vad>`_: MarbleNet
3 changes: 2 additions & 1 deletion docs/source/asr/data/benchmark_en.csv
@@ -34,4 +34,5 @@ stt_en_fastconformer_ctc_xlarge,EncDecCTCModelBPE,"https://ngc.nvidia.com/catalo
stt_en_fastconformer_transducer_xxlarge,EncDecRNNTBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_transducer_xxlarge"
stt_en_fastconformer_hybrid_large_streaming_80ms,EncDecHybridRNNTCTCBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_80ms"
stt_en_fastconformer_hybrid_large_streaming_480ms,EncDecHybridRNNTCTCBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_480ms"
stt_en_fastconformer_hybrid_large_streaming_1040ms,EncDecHybridRNNTCTCBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_1040ms"
stt_en_fastconformer_hybrid_large_streaming_multi,EncDecHybridRNNTCTCBPEModel,"https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_multi"
8 changes: 8 additions & 0 deletions docs/source/asr/data/scores/en/conformer_en.csv
@@ -18,3 +18,11 @@ stt_en_fastconformer_hybrid_large_streaming_1040ms (CTC),en,,,,,2.7 %,6.4 %,,,9.
stt_en_fastconformer_hybrid_large_streaming_80ms (RNNT),en,,,,,2.7 %,6.5 %,,,9.1 %,6.9 %,,,,,,,3.2 %,1.9 %
stt_en_fastconformer_hybrid_large_streaming_480ms (RNNT),en,,,,,2.7 %,6.1 %,,,8.5 %,6.7 %,,,,,,,3.1 %,1.8 %
stt_en_fastconformer_hybrid_large_streaming_1040ms (RNNT),en,,,,,2.3 %,5.5 %,,,8.0 %,6.6 %,,,,,,,2.9 %,1.6 %
stt_en_fastconformer_hybrid_large_streaming_multi (RNNT - 0ms),en,,,,,,7.0 %,,,,,,,,,,,,
stt_en_fastconformer_hybrid_large_streaming_multi (RNNT - 80ms),en,,,,,,6.4 %,,,,,,,,,,,,
stt_en_fastconformer_hybrid_large_streaming_multi (RNNT - 480ms),en,,,,,,5.7 %,,,,,,,,,,,,
stt_en_fastconformer_hybrid_large_streaming_multi (RNNT - 1040ms),en,,,,,,5.4 %,,,,,,,,,,,,
stt_en_fastconformer_hybrid_large_streaming_multi (CTC - 0ms),en,,,,,,8.4 %,,,,,,,,,,,,
stt_en_fastconformer_hybrid_large_streaming_multi (CTC - 80ms),en,,,,,,7.8 %,,,,,,,,,,,,
stt_en_fastconformer_hybrid_large_streaming_multi (CTC - 480ms),en,,,,,,6.7 %,,,,,,,,,,,,
stt_en_fastconformer_hybrid_large_streaming_multi (CTC - 1040ms),en,,,,,,6.2 %,,,,,,,,,,,,
38 changes: 26 additions & 12 deletions docs/source/asr/models.rst
@@ -175,7 +175,7 @@ We support the following three right context modeling:
* fully causal model with zero look-ahead: tokens would not see any future tokens. Convolution layers are all causal and tokens on the right are masked for self-attention.

It gives zero latency but with limited accuracy.
To train such a model, you need to set `encoder.att_context_size=[left_context, 0]` and `encoder.conv_context_size=causal` in the config.
To train such a model, you need to set `model.encoder.att_context_size=[left_context,0]` and `model.encoder.conv_context_size=causal` in the config.

* regular look-ahead: convolutions would be able to see a few future frames, and self-attention would also see the same number of future tokens.

@@ -186,13 +186,11 @@ For example for a model of 17 layers with 4x downsampling and 10ms window shift,

For example, in a model with a chunk size of 20 tokens, tokens at the first position of each chunk would see all the next 19 tokens, while the last token would see zero future tokens.
This approach is more efficient than regular look-ahead in terms of computation, as the activations for most of the look-ahead part would be cached and there is close to zero duplication in the calculations.
In terms of accuracy, this approach gives similar or even better results in term of accuracy than regular look-ahead as each token in each layer have access to more tokens on average. That is why we recommend to use this approach for streaming.

In terms of accuracy, this approach gives similar or even better results than regular look-ahead, as each token in each layer has access to more tokens on average. That is why we recommend this approach for streaming, and therefore recommend chunk-aware look-ahead for cache-aware models.
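As a small illustration of the chunk-aware pattern (a sketch, using the chunk size of 20 from the example above):

.. code-block:: python

    # Number of future tokens visible to each position inside one chunk of 20 tokens:
    # the first position sees the remaining 19 tokens of its chunk, the last position sees none.
    chunk_size = 20
    future_visible = [chunk_size - 1 - pos for pos in range(chunk_size)]
    # future_visible == [19, 18, ..., 1, 0]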

** Note: Latencies are based on the assumption that the forward time of the network is zero; they only estimate the time needed after a frame becomes available until it is passed through the model.

Approaches with non-zero look-ahead can give significantly better accuracy by sacrificing latency. The latency can get controlled by the left context size. Increasing the right context would help the accuracy to a limit but would increase the compuation time.

Approaches with non-zero look-ahead can give significantly better accuracy by sacrificing latency. The latency can be controlled by the left context size. Increasing the right context would help the accuracy to a limit but would increase the computation time.

In all modes, the left context can be controlled by the number of tokens visible in the self-attention and the kernel size of the convolutions.
For example, if the left context of self-attention in each layer is set to 20 tokens and there are 10 layers of Conformer, then the effective left context is 20*10=200 tokens.
@@ -202,19 +200,35 @@ Left context of convolutions depends on their kernel size, while it can
A self-attention left context of around 6 seconds would give results close to those of an unlimited left context. For a model with 4x downsampling and a shift window of 10ms in the preprocessor, each token corresponds to 4*10=40ms.
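The token/time bookkeeping above can be reproduced in a few lines (a sketch assuming the 4x downsampling and 10ms window shift used in these examples):

.. code-block:: python

    # With 4x downsampling and a 10 ms window shift, each encoder token covers 40 ms of audio.
    downsampling, window_shift_ms = 4, 10
    ms_per_token = downsampling * window_shift_ms              # 4 * 10 = 40 ms

    # Effective left context of self-attention accumulates across layers.
    left_context_tokens, num_layers = 20, 10
    effective_left_context = left_context_tokens * num_layers  # 20 * 10 = 200 tokens

    # A left context of ~6 seconds therefore corresponds to roughly this many tokens.
    tokens_for_6_sec = 6000 // ms_per_token                    # 150 tokens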

If the striding approach is used for downsampling, all the convolutions in downsampling would be fully causal and would not see future tokens.
You may use stacking for downsampling in the streaming models, which is significantly faster and uses less memory.
It also avoids some of the limitations of striding and vggnet, and you may use any downsampling rate.

You may find the example config files of cache-aware streaming Conformer models at
``<NeMo_git_root>/examples/asr/conf/conformer/streaming/conformer_transducer_bpe_streaming.yaml`` for Transducer variant and
at ``<NeMo_git_root>/examples/asr/conf/conformer/streaming/conformer_ctc_bpe.yaml`` for CTC variant.
* Multiple Look-aheads
We support multiple look-aheads for cache-aware models. You may specify a list of context sizes for att_context_size.
During training, different context sizes would be selected randomly, with the distribution specified by att_context_probs.
For example, you may enable multiple look-aheads by setting `model.encoder.att_context_size=[[70,13],[70,6],[70,1],[70,0]]` for training.
The first item in the list would be the default during test/validation/inference. To switch between different look-aheads, you may use the method `asr_model.encoder.set_default_att_context_size(att_context_size)` or set att_context_size like the following when using the script `transcribe_speech.py`:

.. code-block:: bash

    python [NEMO_GIT_FOLDER]/examples/asr/transcribe_speech.py \
        pretrained_name="stt_en_fastconformer_hybrid_large_streaming_multi" \
        audio_dir="<DIRECTORY CONTAINING AUDIO FILES>" \
        att_context_size=[70,0]
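The same switch can also be made from Python. Below is a minimal sketch (the model name comes from the checkpoints listed above, ``set_default_att_context_size`` is the encoder method documented here, and the audio path is a placeholder):

.. code-block:: python

    import nemo.collections.asr as nemo_asr

    # Restore the multi-lookahead cache-aware streaming model from NGC.
    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name="stt_en_fastconformer_hybrid_large_streaming_multi"
    )

    # Select one of the look-aheads the model was trained with, e.g. the fully causal one.
    asr_model.encoder.set_default_att_context_size(att_context_size=[70, 0])

    # Transcribe as usual; the selected look-ahead is now the default for inference.
    transcriptions = asr_model.transcribe(["<path to an audio file>"])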
You may find the example config files for cache-aware streaming FastConformer models at
``<NeMo_git_root>/examples/asr/conf/fastconformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml`` for Transducer variant and
at ``<NeMo_git_root>/examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe.yaml`` for CTC variant. It is recommended to use FastConformer, as it is more than 2X faster in both training and inference than regular Conformer.
The hybrid versions of FastConformer can be found here: ``<NeMo_git_root>/examples/asr/conf/conformer/hybrid_cache_aware_streaming/``

Examples for regular Conformer can be found at
``<NeMo_git_root>/examples/asr/conf/conformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml`` for Transducer variant and
at ``<NeMo_git_root>/examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe.yaml`` for CTC variant.

To simulate cache-aware streaming, you may use the script at ``<NeMo_git_root>/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py``. It can simulate streaming in single-stream or multi-stream mode (in batches) for an ASR model.
This script can be used for models trained offline with full context, but the accuracy would not be great unless the chunk size is large enough, which would result in high latency.
It is recommended to train a model in streaming mode with limited context for use with this script. More info can be found in the script.

You may find FastConformer variants of cache-aware streaming models under ``<NeMo_git_root>/examples/asr/conf/fastconformer/``.

Note that cache-aware streaming models are exported without caching support by default.
To include caching support, `model.set_export_config({'cache_support' : 'True'})` should be called before export.
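A minimal sketch of that in-code path (assuming a model already restored in Python and an arbitrary output filename):

.. code-block:: python

    # Enable cache support before exporting a cache-aware streaming model.
    asr_model.set_export_config({'cache_support': 'True'})
    asr_model.export('cache_aware_model.onnx')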
Or, if ``<NeMo_git_root>/scripts/export.py`` is being used:
@@ -46,10 +46,13 @@
It may result in slightly different outputs from the sub-sampling module compared to offline mode for some techniques like striding and sw_striding.
Enabling it would make it easier to export the model to ONNX.
# Hybrid ASR models
## Hybrid ASR models
For Hybrid ASR models which have two decoders, you may select the decoder by --set_decoder DECODER_TYPE, where DECODER_TYPE can be "ctc" or "rnnt".
If the decoder is not set, then the default decoder would be used, which is the RNNT decoder for Hybrid ASR models.
## Multi-lookahead models
For models which support multiple lookaheads, the default is the first one in the list of model.encoder.att_context_size. To change it, you may use --att_context_size, for example --att_context_size [70,1].
## Evaluate a model trained with full context for offline mode
@@ -58,7 +61,7 @@
The accuracy of the model on the borders of chunks would not be very good.
To use a model trained with full context, you need to pass the chunk_size and shift_size arguments.
If shift_size is not passed, chunk_size would be use as the shift_size too.
If shift_size is not passed, chunk_size would be used as the shift_size too.
Also, the argument online_normalization should be enabled to simulate realistic streaming.
The following command would simulate cache-aware streaming on a pretrained model from NGC with chunk_size of 100, shift_size of 50 and 2 left chunks as left context.
The chunk_size of 100 would be 100*4*10=4000ms for a model with 4x downsampling and 10ms shift in feature extraction.
@@ -273,6 +276,13 @@ def main():
    help="Selects the decoder for Hybrid ASR models which has both the CTC and RNNT decoder. Supported decoders are ['ctc', 'rnnt']",
)

parser.add_argument(
    "--att_context_size",
    type=str,
    default=None,
    help="Sets the att_context_size for the models which support multiple lookaheads",
)

args = parser.parse_args()
if (args.audio_file is None and args.manifest_file is None) or (
    args.audio_file is not None and args.manifest_file is not None
@@ -293,6 +303,12 @@ def main():
else:
    raise ValueError("Decoder cannot get changed for non-Hybrid ASR models.")

if args.att_context_size is not None:
    if hasattr(asr_model.encoder, "set_default_att_context_size"):
        asr_model.encoder.set_default_att_context_size(att_context_size=json.loads(args.att_context_size))
    else:
        raise ValueError("Model does not support multiple lookaheads.")

global autocast
if (
args.use_amp
5 changes: 5 additions & 0 deletions examples/asr/speech_to_text_eval.py
@@ -76,6 +76,11 @@ class EvaluationConfig(transcribe_speech.TranscriptionConfig):
dataset_manifest: str = MISSING
output_filename: Optional[str] = "evaluation_transcripts.json"

# decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Joint RNNT/CTC models
decoder_type: Optional[str] = None
# att_context_size can be set for cache-aware streaming models with multiple look-aheads
att_context_size: Optional[list] = None

use_cer: bool = False
tolerance: Optional[float] = None

7 changes: 6 additions & 1 deletion examples/asr/transcribe_speech.py
@@ -153,8 +153,10 @@ class TranscriptionConfig:
# Decoding strategy for RNNT models
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1)

# decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Joint RNNT/CTC models
# decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models
decoder_type: Optional[str] = None
# att_context_size can be set for cache-aware streaming models with multiple look-aheads
att_context_size: Optional[list] = None

# Use this for model-specific changes before transcription
model_change: ModelChangeConfig = ModelChangeConfig()
@@ -246,6 +248,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
if cfg.decoder_type and cfg.decoder_type != 'rnnt':
    raise ValueError('RNNT model only support rnnt decoding!')

if cfg.att_context_size is not None and hasattr(asr_model.encoder, 'set_default_att_context_size'):
    asr_model.encoder.set_default_att_context_size(cfg.att_context_size)

# Setup decoding strategy
if hasattr(asr_model, 'change_decoding_strategy'):
    if cfg.decoder_type is not None:
4 changes: 3 additions & 1 deletion examples/asr/transcribe_speech_parallel.py
@@ -102,8 +102,10 @@ class ParallelTranscriptionConfig:
# decoding strategy for RNNT models
rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig()

# decoder for hybrid models, must be one of 'ctc', 'rnnt' if not None
# decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models
decoder_type: Optional[str] = None
# att_context_size can be set for cache-aware streaming models with multiple look-aheads
att_context_size: Optional[list] = None

trainer: TrainerConfig = TrainerConfig(devices=-1, accelerator="gpu", strategy="ddp")

7 changes: 7 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -544,4 +544,11 @@ def list_available_models(cls) -> List[PretrainedModelInfo]:
)
results.append(model)

model = PretrainedModelInfo(
    pretrained_model_name="stt_en_fastconformer_hybrid_large_streaming_multi",
    description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_fastconformer_hybrid_large_streaming_multi",
    location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_fastconformer_hybrid_large_streaming_multi/versions/1.20.0/files/stt_en_fastconformer_hybrid_large_streaming_multi.nemo",
)
results.append(model)

return results
5 changes: 5 additions & 0 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -47,6 +47,7 @@
from nemo.core.classes.mixins import AccessMixin, adapter_mixins
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types import AcousticEncodedRepresentation, ChannelType, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging

__all__ = ['ConformerEncoder']

@@ -778,6 +779,10 @@ def _calc_context_sizes(
    return att_context_size_all, att_context_size_all[0], att_context_probs, conv_context_size

def set_default_att_context_size(self, att_context_size):
    if att_context_size not in self.att_context_size_all:
        logging.warning(
            f"att_context_size={att_context_size} is not among the list of the supported look-aheads: {self.att_context_size_all}"
        )
    self.att_context_size = att_context_size
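    # Illustrative behaviour (example values, not from this file): for a model trained with
    # att_context_size_all == [[70, 13], [70, 6], [70, 1], [70, 0]],
    # set_default_att_context_size([70, 6]) simply updates the default, while
    # set_default_att_context_size([70, 2]) logs the warning above but still applies the value.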

def setup_streaming_params(
