diff --git a/examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe_streaming.yaml b/examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe_streaming.yaml index 98f23458cd86..32afd919a454 100644 --- a/examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe_streaming.yaml +++ b/examples/asr/conf/conformer/cache_aware_streaming/conformer_ctc_bpe_streaming.yaml @@ -103,10 +103,16 @@ model: # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one - # for chunked_limited you may calculate the look-ahead or right context by the following formula: # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 27*4*0.01=1.08s + + # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly selected for each batch with the distribution specified by att_context_probs. + # The first item in the list is used as the default during test/validation/inference. + # An example of settings for multi-lookahead: + # att_context_size: [[140,27],[140,13],[140,1],[140,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25] att_context_size: [140, 27] # -1 means unlimited context att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null xscaling: true # scales up the input embeddings by sqrt(d_model) untie_biases: true # unties the biases of the TransformerXL layers diff --git a/examples/asr/conf/conformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml b/examples/asr/conf/conformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml index 9d6e3a54d9fe..d55e5f927b2e 100644 --- a/examples/asr/conf/conformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml +++ b/examples/asr/conf/conformer/cache_aware_streaming/conformer_transducer_bpe_streaming.yaml @@ -113,10 +113,16 @@ model: # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one - # for chunked_limited you may calculate the look-ahead or right context by the following formula: - # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 27*4*0.01=1.08s + # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 27*4*0.01=1.08s + + # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly selected for each batch with the distribution specified by att_context_probs. + # The first item in the list is used as the default during test/validation/inference.
+ # An example of settings for multi-lookahead: + # att_context_size: [[140,27],[140,13],[140,1],[140,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25] att_context_size: [140, 27] # -1 means unlimited context att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null xscaling: true # scales up the input embeddings by sqrt(d_model) untie_biases: true # unties the biases of the TransformerXL layers diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml index c68b30a33d5a..749216b1925d 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml @@ -97,10 +97,17 @@ model: # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one - # for chunked_limited you may calculate the look-ahead or right context by the following formula: # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + + # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly selected for each batch with the distribution specified by att_context_probs. + # The first item in the list is used as the default during test/validation/inference. + # An example of settings for multi-lookahead: + # att_context_size: [[70,13],[70,6],[70,1],[70,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25] att_context_size: [70, 13] # -1 means unlimited context att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null + xscaling: true # scales up the input embeddings by sqrt(d_model) pos_emb_max_len: 5000 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml index 654895ec065d..17345119c529 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml @@ -100,11 +100,19 @@ model: n_heads: 8 # may need to be lower for smaller d_models # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention - # for att_context_style=regular, the right context is recommended to be a small number around 0 to 2 as multiple-layers may increase the effective right context too large + # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + + # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly selected for each batch with the distribution specified by att_context_probs.
+ # The first item in the list is used as the default during test/validation/inference. + # An example of settings for multi-lookahead: + # att_context_size: [[70,13],[70,6],[70,1],[70,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25] att_context_size: [70, 13] # -1 means unlimited context att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null + xscaling: true # scales up the input embeddings by sqrt(d_model) pos_emb_max_len: 5000 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml index 5f223061a420..dbd036458cb8 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml @@ -102,10 +102,17 @@ model: # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one - # for chunked_limited you may calculate the look-ahead or right context by the following formula: # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + + # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly selected for each batch with the distribution specified by att_context_probs. + # The first item in the list is used as the default during test/validation/inference. + # An example of settings for multi-lookahead: + # att_context_size: [[70,13],[70,6],[70,1],[70,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25] att_context_size: [70, 13] # -1 means unlimited context att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null + xscaling: true # scales up the input embeddings by sqrt(d_model) pos_emb_max_len: 5000 @@ -191,9 +198,9 @@ model: loss_name: "default" warprnnt_numba_kwargs: # FastEmit regularization: https://arxiv.org/abs/2010.11148 - # You may enable FastEmit to reduce the latency of the model for streaming - # It also helps to improve the accuracy of the model in streaming mode - fastemit_lambda: 1e-3 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming + # You may set it to lower values like 1e-3 for models with larger right context + fastemit_lambda: 5e-3 # Recommended values to be in range [1e-4, 1e-2], 0.005 is a good start. clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
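The look-ahead formula quoted in these configs is easy to sanity-check numerically. Below is a small illustrative Python sketch (not part of the configs) that reproduces the examples from the comments above; the subsampling factor and window stride are the values stated there (8x and 0.01 s for FastConformer, 4x and 0.01 s for Conformer):

```python
# look-ahead(secs) = att_context_size[1] * subsampling_factor * window_stride

def lookahead_secs(right_context: int, subsampling_factor: int, window_stride: float) -> float:
    """Extra latency introduced by the right attention context, in seconds."""
    return right_context * subsampling_factor * window_stride

# FastConformer streaming configs: 8x subsampling, 10 ms window stride
print(lookahead_secs(13, 8, 0.01))  # 1.04 s for att_context_size [70, 13]
print(lookahead_secs(0, 8, 0.01))   # 0.0 s, fully causal [70, 0]

# Conformer streaming configs: 4x subsampling
print(lookahead_secs(27, 4, 0.01))  # 1.08 s for att_context_size [140, 27]
```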
optim: diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml index 68a78ba60aac..50f73d35ca75 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml @@ -106,11 +106,19 @@ model: n_heads: 8 # may need to be lower for smaller d_models # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention - # for att_context_style=regular, the right context is recommended to be a small number around 0 to 2 as multiple-layers may increase the effective right context too large + # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + + # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly selected for each batch with the distribution specified by att_context_probs. + # The first item in the list is used as the default during test/validation/inference. + # An example of settings for multi-lookahead: + # att_context_size: [[70,13],[70,6],[70,1],[70,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25] att_context_size: [70, 13] # -1 means unlimited context att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null + xscaling: true # scales up the input embeddings by sqrt(d_model) pos_emb_max_len: 5000 @@ -196,9 +204,9 @@ model: loss_name: "default" warprnnt_numba_kwargs: # FastEmit regularization: https://arxiv.org/abs/2010.11148 - # You may enable FastEmit to reduce the latency of the model for streaming - # It also helps to improve the accuracy of the model in streaming mode - fastemit_lambda: 1e-3 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming + # You may set it to lower values like 1e-3 for models with larger right context + fastemit_lambda: 5e-3 # Recommended values to be in range [1e-4, 1e-2], 0.005 is a good start. clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
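Before training with a multi-lookahead configuration, it can be worth checking the values against the constraints stated in the comments above: one probability per context size, probabilities summing to one, and (for chunked_limited) a left context divisible by the right context plus one. The helper below is a hypothetical standalone sketch, not a NeMo utility:

```python
def check_multi_lookahead_cfg(att_context_size, att_context_probs=None, att_context_style="chunked_limited"):
    """Sanity-check a (possibly multi-lookahead) attention context configuration."""
    if isinstance(att_context_size[0], int):
        att_context_size = [att_context_size]  # single [left, right] pair -> list of one
    if att_context_probs is not None:
        assert len(att_context_probs) == len(att_context_size), "need one probability per context size"
        assert abs(sum(att_context_probs) - 1.0) < 1e-6, "att_context_probs must sum to one"
    for left, right in att_context_size:
        if att_context_style == "chunked_limited":
            assert right >= 0, "right context cannot be unlimited for chunked_limited"
            if left > 0:
                assert left % (right + 1) == 0, "left context must be divisible by right context + 1"

# Matches the FastConformer example above
check_multi_lookahead_cfg([[70, 13], [70, 6], [70, 1], [70, 0]], [0.25, 0.25, 0.25, 0.25])
```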
optim: diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml index 8b7a2ce7b39d..26dabaa039fe 100644 --- a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml @@ -8,6 +8,8 @@ # FastConformer-CTC's architecture config: NeMo/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml # FastConformer-Transducer's architecture config, along with the optimal batch size and precision: NeMo/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml +# Note: if the training loss does not converge, you may increase the warm-up steps to 20K. + name: "FastConformer-Hybrid-Transducer-CTC-BPE-Streaming" model: @@ -106,8 +108,15 @@ model: # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + + # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly selected for each batch with the distribution specified by att_context_probs. + # The first item in the list is used as the default during test/validation/inference. + # An example of settings for multi-lookahead: + # att_context_size: [[70,13],[70,6],[70,1],[70,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25] att_context_size: [70, 13] # -1 means unlimited context att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null xscaling: true # scales up the input embeddings by sqrt(d_model) pos_emb_max_len: 5000 @@ -206,9 +215,9 @@ model: loss_name: "default" warprnnt_numba_kwargs: # FastEmit regularization: https://arxiv.org/abs/2010.11148 - # You may enable FastEmit to reduce the latency of the model for streaming - # It also helps to improve the accuracy of the model in streaming mode - fastemit_lambda: 1e-3 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming + # You may set it to lower values like 1e-3 for models with larger right context + fastemit_lambda: 5e-3 # Recommended values to be in range [1e-4, 1e-2], 0.005 is a good start. clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
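The training-time behaviour described in these comments comes down to a weighted random choice over the configured context sizes; the encoder changes later in this diff implement it with random.choices in forward_internal. A simplified illustration of that sampling logic:

```python
import random
from collections import Counter

att_context_size_all = [[70, 13], [70, 6], [70, 1], [70, 0]]
att_context_probs = [0.25, 0.25, 0.25, 0.25]  # uniform when att_context_probs is null

def pick_context(training: bool):
    # One draw per forward pass during training; eval always uses the first (default) entry.
    if training and len(att_context_size_all) > 1:
        return random.choices(att_context_size_all, weights=att_context_probs)[0]
    return att_context_size_all[0]

counts = Counter(tuple(pick_context(training=True)) for _ in range(10000))
print(counts)                        # roughly 2500 draws for each context size
print(pick_context(training=False))  # [70, 13]
```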
optim: diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml index a24829b50788..d8362636f04a 100644 --- a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml @@ -8,6 +8,8 @@ # FastConformer-CTC's architecture config: NeMo/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml # FastConformer-Transducer's architecture config, along with the optimal batch size and precision: NeMo/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml +# Note: if the training loss does not converge, you may increase the warm-up steps to 20K. + name: "FastConformer-Hybrid-Transducer-CTC-Char-Streaming" model: @@ -111,8 +113,15 @@ model: # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3 as multiple-layers may increase the effective right context too large # for att_context_style=chunked_limited, the left context need to be dividable by the right context plus one # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s + + # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly selected for each batch with the distribution specified by att_context_probs. + # The first item in the list is used as the default during test/validation/inference. + # An example of settings for multi-lookahead: + # att_context_size: [[70,13],[70,6],[70,1],[70,0]] + # att_context_probs: [0.25, 0.25, 0.25, 0.25] att_context_size: [70, 13] # -1 means unlimited context att_context_style: chunked_limited # regular or chunked_limited + att_context_probs: null xscaling: true # scales up the input embeddings by sqrt(d_model) pos_emb_max_len: 5000 @@ -211,9 +220,9 @@ model: loss_name: "default" warprnnt_numba_kwargs: # FastEmit regularization: https://arxiv.org/abs/2010.11148 - # You may enable FastEmit to reduce the latency of the model for streaming - # It also helps to improve the accuracy of the model in streaming mode - fastemit_lambda: 1e-3 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming + # You may set it to lower values like 1e-3 for models with larger right context + fastemit_lambda: 5e-3 # Recommended values to be in range [1e-4, 1e-2], 0.005 is a good start. clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. optim: diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 7c786f9c9720..74c255741039 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -13,6 +13,7 @@ # limitations under the License. import math +import random from collections import OrderedDict from dataclasses import dataclass from typing import List, Optional, Set @@ -89,9 +90,13 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): Defaults to 5000 n_heads (int): number of heads in multi-headed attention layers Defaults to 4.
- att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes, - or None for full context. - Defaults to None. + att_context_size (List[Union[List[int],int]]): specifies the context sizes on each side. Each context size should be a list of two integers like [100,100]. + A list of context sizes like [[100,100],[100,50]] can also be passed. -1 means unlimited context. + Defaults to [-1,-1] + att_context_probs (List[float]): a list of probabilities, one per att_context_size, used when a list of context sizes is passed. If not specified, a uniform distribution is used. + Defaults to None + att_context_style (str): 'regular' or 'chunked_limited'. + Defaults to 'regular' xscaling (bool): enables scaling the inputs to the multi-headed attention layers by sqrt(d_model) Defaults to True. untie_biases (bool): whether to not share (untie) the bias weights between layers of Transformer-XL @@ -100,6 +105,11 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): Defaults to 31. conv_norm_type (str): the type of the normalization in the convolutional modules Defaults to 'batch_norm'. + conv_context_size (list): it can be 'causal' or a list of two integers where conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size. + None means [(conv_kernel_size-1)//2, (conv_kernel_size-1)//2], and 'causal' means [(conv_kernel_size-1), 0]. + Defaults to None. + conv_dual_mode (bool): specifies whether the convolution should be dual mode when the dual_offline mode is used. When enabled, the left half of the convolution kernel is masked in streaming cases. + Defaults to False dropout (float): the dropout rate used in all layers except the attention layers Defaults to 0.1. dropout_pre_encoder (float): the dropout rate used before the encoder @@ -256,6 +266,7 @@ def __init__( self_attention_model='rel_pos', n_heads=4, att_context_size=None, + att_context_probs=None, att_context_style='regular', xscaling=True, untie_biases=True, @@ -279,7 +290,6 @@ self.d_model = d_model self.n_layers = n_layers self._feat_in = feat_in - self.scale = math.sqrt(self.d_model) self.att_context_style = att_context_style self.subsampling_factor = subsampling_factor self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor @@ -289,51 +299,19 @@ self.global_attn_separate = global_attn_separate self.global_tokens_spacing = global_tokens_spacing - if att_context_size: - self.att_context_size = list(att_context_size) - else: - self.att_context_size = [-1, -1] - - if isinstance(conv_context_size, ListConfig): - conv_context_size = list(conv_context_size) - - if conv_context_size is not None: - if ( - not isinstance(conv_context_size, list) - and not isinstance(conv_context_size, str) - and not isinstance(conv_context_size, ListConfig) - ): - raise ValueError( - f"Invalid conv_context_size! It should be the string 'causal' or a list of two integers."
- ) - if conv_context_size == "causal": - conv_context_size = [conv_kernel_size - 1, 0] - else: - if conv_context_size[0] + conv_context_size[1] + 1 != conv_kernel_size: - raise ValueError(f"Invalid conv_context_size: {self.conv_context_size}!") - else: - conv_context_size = [(conv_kernel_size - 1) // 2, (conv_kernel_size - 1) // 2] - self.conv_context_size = conv_context_size - - if att_context_style == "chunked_limited": - # the left context for self-attention in chunked_limited mode should be dividable by the right context - # right context=att_context_size[1]+1, and left_context=self.att_context_size[0] - if self.att_context_size[0] > 0 and self.att_context_size[0] % (self.att_context_size[1] + 1) > 0: - raise ValueError("att_context_size[0] % (att_context_size[1] + 1) should be zero!") - if self.att_context_size[1] < 0: - raise ValueError("Right context can not be unlimited for chunked_limited style!") - self.chunk_size = self.att_context_size[1] + 1 - - # left_chunks_num specifies the number of chunks to be visible by each chunk on the left side - if self.att_context_size[0] >= 0: - self.left_chunks_num = self.att_context_size[0] // self.chunk_size - else: - self.left_chunks_num = 100000 - - elif att_context_style == "regular": - self.chunk_size = None - else: - raise ValueError("Invalid att_context_style!") + # Setting up the att_context_size + ( + self.att_context_size_all, + self.att_context_size, + self.att_context_probs, + self.conv_context_size, + ) = self._calc_context_sizes( + att_context_style=att_context_style, + att_context_size=att_context_size, + att_context_probs=att_context_probs, + conv_context_size=conv_context_size, + conv_kernel_size=conv_kernel_size, + ) if xscaling: self.xscale = math.sqrt(d_model) @@ -379,6 +357,7 @@ def __init__( self._feat_out = d_model + # Biases for relative positional encoding if not untie_biases and self_attention_model == "rel_pos": d_head = d_model // n_heads pos_bias_u = nn.Parameter(torch.Tensor(n_heads, d_head)) @@ -389,8 +368,8 @@ def __init__( pos_bias_u = None pos_bias_v = None + # Positional encodings self.pos_emb_max_len = pos_emb_max_len - self.att_mask = None if self_attention_model == "rel_pos": self.pos_enc = RelPositionalEncoding( d_model=d_model, @@ -458,51 +437,6 @@ def __init__( # will be set in self.forward() if defined in AccessMixin config self.interctc_capture_at_layers = None - def update_max_seq_length(self, seq_length: int, device): - # Find global max audio length across all nodes - if torch.distributed.is_initialized(): - global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) - - # Update across all ranks in the distributed system - torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX) - - seq_length = global_max_len.to(torch.int64).item() - - if seq_length > self.max_audio_length: - self.set_max_audio_length(seq_length) - - def set_max_audio_length(self, max_audio_length): - """ - Sets maximum input length. - Pre-calculates internal seq_range mask. 
- """ - self.max_audio_length = max_audio_length - device = next(self.parameters()).device - self.pos_enc.extend_pe(max_audio_length, device) - - if self.self_attention_model != "rel_pos_local_attn": - att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device) - if self.chunk_size is None: - if self.att_context_size[0] >= 0: - att_mask = att_mask.triu(diagonal=-self.att_context_size[0]) - if self.att_context_size[1] >= 0: - att_mask = att_mask.tril(diagonal=self.att_context_size[1]) - else: - chunk_idx = torch.arange(0, max_audio_length, dtype=torch.int64, device=att_mask.device) - chunk_idx = torch.div(chunk_idx, self.chunk_size, rounding_mode="trunc") - diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0) - chunked_limited_mask = torch.logical_and( - torch.le(diff_chunks, self.left_chunks_num), torch.ge(diff_chunks, 0) - ) - att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) - - if hasattr(self, 'att_mask'): - self.att_mask = att_mask - else: - self.register_buffer('att_mask', att_mask, persistent=False) - else: - self.att_mask = None - def forward_for_export( self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ): @@ -565,17 +499,24 @@ def forward_internal( self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ): self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device) - max_audio_length = audio_signal.size(-1) if length is None: length = audio_signal.new_full( - (audio_signal.size(0),), max_audio_length, dtype=torch.int64, device=audio_signal.device + (audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device ) if cache_last_time is not None: cache_last_time_next = torch.zeros_like(cache_last_time) else: cache_last_time_next = None + + # select a random att_context_size with the distribution specified by att_context_probs during training + # for non-validation cases like test, validation or inference, it uses the first mode in self.att_context_size + if self.training and len(self.att_context_size_all) > 1: + cur_att_context_size = random.choices(self.att_context_size_all, weights=self.att_context_probs)[0] + else: + cur_att_context_size = self.att_context_size + audio_signal = torch.transpose(audio_signal, 1, 2) if isinstance(self.pre_encode, nn.Linear): @@ -588,11 +529,10 @@ def forward_internal( audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :] length = (length - self.streaming_cfg.drop_extra_pre_encoded).clamp(min=0) - max_audio_length = audio_signal.size(1) - if self.reduction_position is not None and cache_last_channel is not None: raise ValueError("Caching with reduction feature is not supported yet!") + max_audio_length = audio_signal.size(1) if cache_last_channel is not None: cache_len = self.streaming_cfg.last_channel_cache_size cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size @@ -606,17 +546,20 @@ def forward_internal( cache_len = 0 offset = None - if self.self_attention_model == 'abs_pos': - audio_signal, pos_emb = self.pos_enc(x=audio_signal) - else: - audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) + audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) # Create the self-attention and padding masks - pad_mask, att_mask = self._create_masks(max_audio_length, padding_length, offset, audio_signal.device) + pad_mask, att_mask = self._create_masks( + 
att_context_size=cur_att_context_size, + padding_length=padding_length, + max_audio_length=max_audio_length, + offset=offset, + device=audio_signal.device, + ) if cache_last_channel is not None: pad_mask = pad_mask[:, cache_len:] - if self.att_mask is not None: + if att_mask is not None: att_mask = att_mask[:, cache_len:] for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)): @@ -650,7 +593,13 @@ def forward_internal( # Don't update the audio_signal here because then it will again scale the audio_signal # and cause an increase in the WER _, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) - pad_mask, att_mask = self._create_masks(max_audio_length, length, offset, audio_signal.device) + pad_mask, att_mask = self._create_masks( + att_context_size=cur_att_context_size, + padding_length=length, + max_audio_length=max_audio_length, + offset=offset, + device=audio_signal.device, + ) # saving tensors if required for interctc loss if self.is_access_enabled(): @@ -687,7 +636,60 @@ def forward_internal( else: return audio_signal, length - def _create_masks(self, max_audio_length, padding_length, offset, device): + def update_max_seq_length(self, seq_length: int, device): + # Find global max audio length across all nodes + if torch.distributed.is_initialized(): + global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) + + # Update across all ranks in the distributed system + torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX) + + seq_length = global_max_len.int().item() + + if seq_length > self.max_audio_length: + self.set_max_audio_length(seq_length) + + def set_max_audio_length(self, max_audio_length): + """ + Sets maximum input length. + Pre-calculates internal seq_range mask. 
+ """ + self.max_audio_length = max_audio_length + device = next(self.parameters()).device + self.pos_enc.extend_pe(max_audio_length, device) + + def _create_masks(self, att_context_size, padding_length, max_audio_length, offset, device): + if self.self_attention_model != "rel_pos_local_attn": + att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device) + + if self.att_context_style == "regular": + if att_context_size[0] >= 0: + att_mask = att_mask.triu(diagonal=-att_context_size[0]) + if att_context_size[1] >= 0: + att_mask = att_mask.tril(diagonal=att_context_size[1]) + elif self.att_context_style == "chunked_limited": + # When right context is unlimited, just the left side of the masking need to get updated + if att_context_size[1] == -1: + if att_context_size[0] >= 0: + att_mask = att_mask.triu(diagonal=-att_context_size[0]) + else: + chunk_size = att_context_size[1] + 1 + # left_chunks_num specifies the number of chunks to be visible by each chunk on the left side + if att_context_size[0] >= 0: + left_chunks_num = att_context_size[0] // chunk_size + else: + left_chunks_num = 10000 + + chunk_idx = torch.arange(0, max_audio_length, dtype=torch.int, device=att_mask.device) + chunk_idx = torch.div(chunk_idx, chunk_size, rounding_mode="trunc") + diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0) + chunked_limited_mask = torch.logical_and( + torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0) + ) + att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) + else: + att_mask = None + # pad_mask is the masking to be used to ignore paddings pad_mask = torch.arange(0, max_audio_length, device=device).expand( padding_length.size(0), -1 @@ -697,24 +699,19 @@ def _create_masks(self, max_audio_length, padding_length, offset, device): pad_mask_off = torch.arange(0, max_audio_length, device=device).expand( padding_length.size(0), -1 ) >= offset.unsqueeze(-1) - pad_mask = pad_mask_off.logical_and(pad_mask) - if self.att_mask is not None: + if att_mask is not None: # pad_mask_for_att_mask is the mask which helps to ignore paddings pad_mask_for_att_mask = pad_mask.unsqueeze(1).repeat([1, max_audio_length, 1]) pad_mask_for_att_mask = torch.logical_and(pad_mask_for_att_mask, pad_mask_for_att_mask.transpose(1, 2)) # att_mask is the masking to be used by the MHA layers to ignore the tokens not supposed to be visible - att_mask = self.att_mask[:, :max_audio_length, :max_audio_length] + att_mask = att_mask[:, :max_audio_length, :max_audio_length] # paddings should also get ignored, so pad_mask_for_att_mask is used to ignore their corresponding scores att_mask = torch.logical_and(pad_mask_for_att_mask, att_mask.to(pad_mask_for_att_mask.device)) - att_mask = ~att_mask - else: - att_mask = None pad_mask = ~pad_mask - return pad_mask, att_mask def enable_pad_mask(self, on=True): @@ -723,8 +720,64 @@ def enable_pad_mask(self, on=True): self.use_pad_mask = on return mask + def _calc_context_sizes( + self, att_context_size, att_context_probs, att_context_style, conv_context_size, conv_kernel_size + ): + # convert att_context_size to a standard list of lists + if att_context_size: + att_context_size_all = list(att_context_size) + if isinstance(att_context_size_all[0], int): + att_context_size_all = [att_context_size_all] + for i, att_cs in enumerate(att_context_size_all): + if isinstance(att_cs, ListConfig): + att_context_size_all[i] = list(att_cs) + if att_context_style == "chunked_limited": + if att_cs[0] > 0 and att_cs[0] % (att_cs[1] + 1) 
> 0: + raise ValueError(f"att_context_size[{i}][0] % (att_context_size[{i}][1] + 1) should be zero!") + if att_cs[1] < 0 and len(att_context_size_all) <= 1: + raise ValueError( + f"Right context (att_context_size[{i}][1]) can not be unlimited for chunked_limited style!" + ) + else: + att_context_size_all = [[-1, -1]] + + if att_context_probs: + if len(att_context_probs) != len(att_context_size_all): + raise ValueError("The size of the att_context_probs should be the same as att_context_size.") + att_context_probs = list(att_context_probs) + if sum(att_context_probs) != 1: + raise ValueError( + "The sum of numbers in att_context_probs should be equal to one to be a distribution." + ) + else: + att_context_probs = [1.0 / len(att_context_size_all)] * len(att_context_size_all) + + if conv_context_size is not None: + if isinstance(conv_context_size, ListConfig): + conv_context_size = list(conv_context_size) + if not isinstance(conv_context_size, list) and not isinstance(conv_context_size, str): + raise ValueError( + f"Invalid conv_context_size! It should be the string 'causal' or a list of two integers." + ) + if conv_context_size == "causal": + conv_context_size = [conv_kernel_size - 1, 0] + else: + if conv_context_size[0] + conv_context_size[1] + 1 != conv_kernel_size: + raise ValueError(f"Invalid conv_context_size: {self.conv_context_size}!") + else: + conv_context_size = [(conv_kernel_size - 1) // 2, (conv_kernel_size - 1) // 2] + return att_context_size_all, att_context_size_all[0], att_context_probs, conv_context_size + + def set_default_att_context_size(self, att_context_size): + self.att_context_size = att_context_size + def setup_streaming_params( - self, chunk_size: int = None, shift_size: int = None, left_chunks: int = None, max_context: int = 10000 + self, + chunk_size: int = None, + shift_size: int = None, + left_chunks: int = None, + att_context_size: list = None, + max_context: int = 10000, ): """ This function sets the needed values and parameters to perform streaming. The configuration would be stored in self.streaming_cfg. 
@@ -737,25 +790,28 @@ def setup_streaming_params( Defaults to -1 (means feat_out is d_model) """ streaming_cfg = CacheAwareStreamingConfig() + + # When att_context_size is not specified, it uses the default_att_context_size + if att_context_size is None: + att_context_size = self.att_context_size + if chunk_size is not None: if chunk_size < 1: raise ValueError("chunk_size needs to be a number larger or equal to one.") lookahead_steps = chunk_size - 1 streaming_cfg.cache_drop_size = chunk_size - shift_size elif self.att_context_style == "chunked_limited": - lookahead_steps = self.att_context_size[1] + lookahead_steps = att_context_size[1] streaming_cfg.cache_drop_size = 0 elif self.att_context_style == "regular": - lookahead_steps = self.att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers + lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers streaming_cfg.cache_drop_size = lookahead_steps else: streaming_cfg.cache_drop_size = 0 lookahead_steps = None if chunk_size is None: - streaming_cfg.last_channel_cache_size = ( - self.att_context_size[0] if self.att_context_size[0] >= 0 else max_context - ) + streaming_cfg.last_channel_cache_size = att_context_size[0] if att_context_size[0] >= 0 else max_context else: if left_chunks is None: raise ValueError("left_chunks can not be None when chunk_size is set.") @@ -878,9 +934,9 @@ def change_attention_model( 'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using overlapping windows. Attention context is determined by att_context_size parameter. 'abs_pos': absolute positional embedding and Transformer - If None is provided, the self_attention_model isn't changed. Defauts to None. + If None is provided, the self_attention_model isn't changed. Defaults to None. att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes, - or None to keep as it is. Defauts to None. + or None to keep as it is. Defaults to None. update_config (bool): Whether to update the config or not with the new attention model. Defaults to True. device (torch.device): If provided, new layers will be moved to the device. 
@@ -889,19 +945,16 @@ def change_attention_model( if att_context_size: att_context_size = list(att_context_size) - elif hasattr(self._cfg, "att_context_size"): - att_context_size = self._cfg.att_context_size else: att_context_size = self.att_context_size if self_attention_model is None: - self_attention_model = self._cfg.self_attention_model + self_attention_model = self.self_attention_model if self_attention_model == 'rel_pos_local_attn' and max(att_context_size) <= 0: raise ValueError("When using local attention, context size must be set > 0") if self_attention_model == "rel_pos": - self.att_mask = None new_pos_enc = RelPositionalEncoding( d_model=self._cfg.d_model, dropout_rate=self._cfg.dropout, @@ -938,7 +991,6 @@ def change_attention_model( for name, m in self.named_modules(): if type(m) == ConformerLayer: - if self_attention_model == 'rel_pos': new_attn = RelPositionMultiHeadAttention( n_head=self._cfg.n_heads, diff --git a/nemo/collections/asr/modules/squeezeformer_encoder.py b/nemo/collections/asr/modules/squeezeformer_encoder.py index 952c9b53d233..a887abd19ebb 100644 --- a/nemo/collections/asr/modules/squeezeformer_encoder.py +++ b/nemo/collections/asr/modules/squeezeformer_encoder.py @@ -149,7 +149,6 @@ def __init__( d_ff = d_model * ff_expansion_factor self.d_model = d_model self._feat_in = feat_in - self.scale = math.sqrt(self.d_model) if att_context_size: self.att_context_size = att_context_size else: diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py index 40baf1141bd3..b7356ffe87e4 100644 --- a/nemo/collections/asr/parts/submodules/multi_head_attention.py +++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py @@ -888,17 +888,19 @@ def extend_pe(self, length, device): positions = torch.arange(0, length, dtype=torch.float32, device=device).unsqueeze(1) self.create_pe(positions=positions) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor, cache_len=0): """Adds positional encoding. Args: x (torch.Tensor): Input. 
Its shape is (batch, time, feature_size) + cache_len (int): the size of the cache which is used to shift positions Returns: x+pos_emb (torch.Tensor): Its shape is (batch, time, feature_size) pos_emb (torch.Tensor): Its shape is (1, time, feature_size) """ + input_len = x.size(1) + cache_len if self.xscale: x = x * self.xscale - pos_emb = self.pe[:, : x.size(1)] + pos_emb = self.pe[:, :input_len] if self.dropout_emb: pos_emb = self.dropout_emb(pos_emb) x = x + pos_emb diff --git a/tests/collections/nlp/test_huggingface.py b/tests/collections/nlp/test_huggingface.py index cfe2845caa9b..0ad7b5850475 100644 --- a/tests/collections/nlp/test_huggingface.py +++ b/tests/collections/nlp/test_huggingface.py @@ -85,12 +85,13 @@ def test_get_pretrained_chinese_bert_wwm_model(self): tokenizer = get_tokenizer(tokenizer_name=model_name) assert isinstance(tokenizer, AutoTokenizer) - @pytest.mark.with_downloads() - @pytest.mark.unit - def test_get_pretrained_arabic_model(self): - model_name = 'asafaya/bert-base-arabic' - self.omega_conf.language_model.pretrained_model_name = model_name - model = nemo_nlp.modules.get_lm_model(cfg=self.omega_conf) - assert isinstance(model, nemo_nlp.modules.BertModule) - tokenizer = get_tokenizer(tokenizer_name=model_name) - assert isinstance(tokenizer, AutoTokenizer) + # model is not on HF anymore + # @pytest.mark.with_downloads() + # @pytest.mark.unit + # def test_get_pretrained_arabic_model(self): + # model_name = 'asafaya/bert-base-arabic' + # self.omega_conf.language_model.pretrained_model_name = model_name + # model = nemo_nlp.modules.get_lm_model(cfg=self.omega_conf) + # assert isinstance(model, nemo_nlp.modules.BertModule) + # tokenizer = get_tokenizer(tokenizer_name=model_name) + # assert isinstance(tokenizer, AutoTokenizer)
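With these changes, a cache-aware model trained with several look-aheads can be switched between them at inference time via the new encoder methods (set_default_att_context_size and the att_context_size argument of setup_streaming_params). A hedged usage sketch; the checkpoint path is a placeholder, and the context values assume a model trained with the [[70,13],[70,6],[70,1],[70,0]] example above:

```python
import nemo.collections.asr as nemo_asr

# Placeholder checkpoint: any cache-aware streaming model trained with multiple att_context_size values.
asr_model = nemo_asr.models.ASRModel.restore_from("multi_lookahead_fastconformer.nemo")

# Lowest latency: pick the fully causal mode (zero look-ahead)...
asr_model.encoder.set_default_att_context_size([70, 0])
# ...or trade latency for accuracy with the largest trained look-ahead.
asr_model.encoder.set_default_att_context_size([70, 13])

# Cache/buffer sizes for streaming can then be derived for the chosen context.
asr_model.encoder.setup_streaming_params(att_context_size=[70, 0])
```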