From 599f522205d217f6dc97baba9f0347e6bd8e332b Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Tue, 25 Apr 2023 12:42:29 -0700
Subject: [PATCH] Fix cache aware hybrid bugs (#6466) (#6484)

Signed-off-by: hsiehjackson
---
 ...ech_to_text_cache_aware_streaming_infer.py | 25 ++++++-
 .../asr/models/hybrid_rnnt_ctc_bpe_models.py  |  6 +-
 .../asr/models/hybrid_rnnt_ctc_models.py      | 14 ++--
 .../asr/modules/conformer_encoder.py          |  2 +-
 nemo/collections/asr/parts/mixins/mixins.py   | 15 +++-
 .../asr/parts/submodules/subsampling.py       | 74 +++++++++++++------
 .../asr/parts/utils/streaming_utils.py        |  6 +-
 .../asr/test_asr_hybrid_rnnt_ctc_model_bpe.py |  2 +-
 8 files changed, 104 insertions(+), 40 deletions(-)

diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
index 589d4c7ec3ee..75912f1c03c1 100644
--- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
+++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
@@ -42,8 +42,14 @@
 
 You may drop the '--debug_mode' and '--compare_vs_offline' to speedup the streaming evaluation.
 If compare_vs_offline is not used, then significantly larger batch_size can be used.
 
+Setting `--pad_and_drop_preencoded` would perform the caching for all steps including the first step.
+It may result in slightly different outputs from the sub-sampling module compared to offline mode for some techniques like striding and sw_striding.
+Enabling it would make it easier to export the model to ONNX.
+
+## Hybrid ASR models
+For Hybrid ASR models which have two decoders, you may select the decoder by --set_decoder DECODER_TYPE, where DECODER_TYPE can be "ctc" or "rnnt".
+If decoder is not set, then the default decoder would be used, which is the RNNT decoder for Hybrid ASR models.
-To best compare output with offline output (i.e. `--compare_vs_offline` is set) `--pad-and-drop-preencoded` should also be set.
 
 ## Evaluate a model trained with full context for offline mode
 
@@ -126,6 +132,7 @@ def perform_streaming(
             transcribed_texts,
             cache_last_channel_next,
             cache_last_time_next,
+            cache_last_channel_len,
             best_hyp,
         ) = asr_model.conformer_stream_step(
             processed_signal=processed_signal,
@@ -254,9 +261,16 @@ def main():
         "--output_path", type=str, help="path to output file when manifest is used as input", default=None
     )
     parser.add_argument(
-        "--pad-and-drop-preencoded",
+        "--pad_and_drop_preencoded",
         action="store_true",
-        help="Enables padding the audio input and then dropping the extra steps after the pre-encoding for the first step. It makes the outputs of the downsampling exactly as the offline mode for some techniques like striding.",
+        help="Enables padding the audio input and then dropping the extra steps after the pre-encoding for all the steps including the first step. It may make the outputs of the downsampling slightly different from offline mode for some techniques like striding or sw_striding.",
+    )
+
+    parser.add_argument(
+        "--set_decoder",
+        choices=["ctc", "rnnt"],
+        default=None,
+        help="Selects the decoder for Hybrid ASR models which have both CTC and RNNT decoders. Supported decoders are ['ctc', 'rnnt']",
     )
 
     args = parser.parse_args()
@@ -273,6 +287,11 @@ def main():
         asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=args.asr_model)
 
     logging.info(asr_model.encoder.streaming_cfg)
+    if args.set_decoder is not None:
+        if hasattr(asr_model, "cur_decoder"):
+            asr_model.change_decoding_strategy(decoder_type=args.set_decoder)
+        else:
+            raise ValueError("Decoder cannot be changed for non-Hybrid ASR models.")
 
     global autocast
     if (
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
index 170aa3f8001a..104b2eb95524 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -124,7 +124,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         )
 
         # setting the RNNT decoder as the default one
-        self.use_rnnt_decoder = True
+        self.cur_decoder = "rnnt"
 
     def _setup_dataloader_from_config(self, config: Optional[Dict]):
         dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config(
@@ -375,7 +375,7 @@ def change_vocabulary(
         logging.info(f"Changed tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.")
 
-    def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str = None):
+    def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None):
         """
         Changes decoding strategy used during RNNT decoding process.
 
         Args:
@@ -446,7 +446,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str =
             with open_dict(self.cfg.aux_ctc.decoding):
                 self.cfg.aux_ctc.decoding = decoding_cfg
 
-            self.use_rnnt_decoder = False
+            self.cur_decoder = "ctc"
             logging.info(
                 f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}"
             )
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
index e3acec2c7420..a413eaeed6fa 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -86,7 +86,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         )
 
         # setting the RNNT decoder as the default one
-        self.use_rnnt_decoder = True
+        self.cur_decoder = "rnnt"
 
         # setting up interCTC loss (from InterCTCMixin)
         self.setup_interctc(decoder_name='ctc_decoder', loss_name='ctc_loss', wer_name='ctc_wer')
@@ -125,7 +125,11 @@ def transcribe(
             * A list of greedy transcript texts / Hypothesis
             * An optional list of beam search transcript texts / Hypothesis / NBestHypothesis.
         """
-        if self.use_rnnt_decoder:
+        if self.cur_decoder not in ["ctc", "rnnt"]:
+            raise ValueError(
+                f"{self.cur_decoder} is not supported for cur_decoder. Supported values are ['ctc', 'rnnt']"
+            )
+        if self.cur_decoder == "rnnt":
             return super().transcribe(
                 paths2audio_files=paths2audio_files,
                 batch_size=batch_size,
@@ -307,7 +311,7 @@ def change_vocabulary(
 
         logging.info(f"Changed the tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.")
 
-    def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str = None):
+    def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None):
         """
         Changes decoding strategy used during RNNT decoding process.
 
@@ -319,7 +323,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str =
                 used. If set to 'ctc', it raises error if 'ctc_decoder' is not an attribute of the model.
""" if decoder_type is None or decoder_type == 'rnnt': - self.use_rnnt_decoder = True + self.cur_decoder = "rnnt" return super().change_decoding_strategy(decoding_cfg=decoding_cfg) assert decoder_type == 'ctc' and hasattr(self, 'ctc_decoder') @@ -346,7 +350,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str = with open_dict(self.cfg.aux_ctc): self.cfg.aux_ctc.decoding = decoding_cfg - self.use_rnnt_decoder = False + self.cur_decoder = "ctc" logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}") # PTL-specific methods diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 032a16d537ea..0fc0912a8921 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -500,7 +500,7 @@ def forward_for_export( def streaming_post_process(self, rets, keep_all_outputs=True): if len(rets) == 2: - return rets + return rets[0], rets[1], None, None, None (encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len) = rets diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index 92977314d08f..f350dbcd5df0 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -491,8 +491,17 @@ def conformer_stream_step( drop_extra_pre_encoded=drop_extra_pre_encoded, ) - if isinstance(self, asr_models.EncDecCTCModel): - log_probs = self.decoder(encoder_output=encoded) + if isinstance(self, asr_models.EncDecCTCModel) or ( + isinstance(self, asr_models.EncDecHybridRNNTCTCModel) and self.cur_decoder == "ctc" + ): + if hasattr(self, "ctc_decoder"): + decoding = self.ctc_decoding + decoder = self.ctc_decoder + else: + decoding = self.decoding + decoder = self.decoder + + log_probs = decoder(encoder_output=encoded) predictions_tensor = log_probs.argmax(dim=-1, keepdim=False) # Concatenate the previous predictions with the current one to have the full predictions. 
@@ -517,7 +526,7 @@ def conformer_stream_step(
 
             # TODO: make decoding more efficient by avoiding the decoding process from the beginning
             if return_transcription:
-                decoded_out = self.decoding.ctc_decoder_predictions_tensor(
+                decoded_out = decoding.ctc_decoder_predictions_tensor(
                     decoder_outputs=greedy_predictions_concat.unsqueeze(0),
                     decoder_lengths=encoded_len[preds_idx : preds_idx + 1],
                     return_hypotheses=False,
diff --git a/nemo/collections/asr/parts/submodules/subsampling.py b/nemo/collections/asr/parts/submodules/subsampling.py
index 5c0e937e0d24..c10f85403b25 100644
--- a/nemo/collections/asr/parts/submodules/subsampling.py
+++ b/nemo/collections/asr/parts/submodules/subsampling.py
@@ -126,42 +126,72 @@ def __init__(
             self._kernel_size = 3
             self._ceil_mode = False
 
-            self._left_padding = (self._kernel_size - 1) // 2
-            self._right_padding = (self._kernel_size - 1) // 2
+            if self.is_causal:
+                self._left_padding = self._kernel_size - 1
+                self._right_padding = self._stride - 1
+                self._max_cache_len = subsampling_factor + 1
+            else:
+                self._left_padding = (self._kernel_size - 1) // 2
+                self._right_padding = (self._kernel_size - 1) // 2
+                self._max_cache_len = 0
 
             # Layer 1
-            layers.append(
-                torch.nn.Conv2d(
-                    in_channels=in_channels,
-                    out_channels=conv_channels,
-                    kernel_size=self._kernel_size,
-                    stride=self._stride,
-                    padding=self._left_padding,
+            if self.is_causal:
+                layers.append(
+                    CausalConv2D(
+                        in_channels=in_channels,
+                        out_channels=conv_channels,
+                        kernel_size=self._kernel_size,
+                        stride=self._stride,
+                        padding=None,
+                    )
+                )
+            else:
+                layers.append(
+                    torch.nn.Conv2d(
+                        in_channels=in_channels,
+                        out_channels=conv_channels,
+                        kernel_size=self._kernel_size,
+                        stride=self._stride,
+                        padding=self._left_padding,
+                    )
                 )
-            )
             in_channels = conv_channels
             layers.append(activation)
 
             for i in range(self._sampling_num - 1):
-                layers.extend(
-                    [
-                        torch.nn.Conv2d(
+                if self.is_causal:
+                    layers.append(
+                        CausalConv2D(
                             in_channels=in_channels,
                             out_channels=in_channels,
                             kernel_size=self._kernel_size,
                             stride=self._stride,
-                            padding=self._left_padding,
+                            padding=None,
                             groups=in_channels,
-                        ),
+                        )
+                    )
+                else:
+                    layers.append(
                         torch.nn.Conv2d(
                             in_channels=in_channels,
-                            out_channels=conv_channels,
-                            kernel_size=1,
-                            stride=1,
-                            padding=0,
-                            groups=1,
-                        ),
-                    ]
+                            out_channels=in_channels,
+                            kernel_size=self._kernel_size,
+                            stride=self._stride,
+                            padding=self._left_padding,
+                            groups=in_channels,
+                        )
+                    )
+
+                layers.append(
+                    torch.nn.Conv2d(
+                        in_channels=in_channels,
+                        out_channels=conv_channels,
+                        kernel_size=1,
+                        stride=1,
+                        padding=0,
+                        groups=1,
+                    )
                 )
                 layers.append(activation)
                 in_channels = conv_channels
diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py
index 622b4fe57478..b824bc18e770 100644
--- a/nemo/collections/asr/parts/utils/streaming_utils.py
+++ b/nemo/collections/asr/parts/utils/streaming_utils.py
@@ -1367,9 +1367,10 @@ def __iter__(self):
             )
 
         if self.buffer_idx == 0 and isinstance(self.streaming_cfg.shift_size, list):
-            shift_size = self.streaming_cfg.shift_size[0]
             if self.pad_and_drop_preencoded:
                 shift_size = self.streaming_cfg.shift_size[1]
+            else:
+                shift_size = self.streaming_cfg.shift_size[0]
         else:
             shift_size = (
                 self.streaming_cfg.shift_size[1]
@@ -1394,9 +1395,10 @@ def __iter__(self):
         # if there is not enough frames to be used as the pre-encoding cache, zeros would be added
         zeros_pads = None
         if self.buffer_idx == 0 and isinstance(self.streaming_cfg.pre_encode_cache_size, list):
-            cache_pre_encode_num_frames = self.streaming_cfg.pre_encode_cache_size[0]
             if self.pad_and_drop_preencoded:
                 cache_pre_encode_num_frames = self.streaming_cfg.pre_encode_cache_size[1]
+            else:
+                cache_pre_encode_num_frames = self.streaming_cfg.pre_encode_cache_size[0]
             cache_pre_encode = torch.zeros(
                 (audio_chunk.size(0), self.input_features, cache_pre_encode_num_frames),
                 device=audio_chunk.device,
diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
index e59353102c39..0f3611f95153 100644
--- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
+++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
@@ -306,4 +306,4 @@ def test_decoding_change(self, hybrid_asr_model):
         assert hybrid_asr_model.ctc_decoding.preserve_alignments is True
         assert hybrid_asr_model.ctc_decoding.compute_timestamps is True
 
-        assert hybrid_asr_model.use_rnnt_decoder is False
+        assert hybrid_asr_model.cur_decoder == "ctc"
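
For reference, a minimal usage sketch of the behavior this patch introduces
(not part of the patch itself; the checkpoint name is a placeholder, and only
calls that appear in the diffs above are used):

    import nemo.collections.asr as nemo_asr

    # Load a Hybrid RNNT-CTC model; the checkpoint name here is hypothetical.
    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name="<hybrid_rnnt_ctc_checkpoint>")

    # After __init__, the RNNT decoder is the default one.
    print(asr_model.cur_decoder)  # -> "rnnt"

    # Switch to the auxiliary CTC decoder; decoding_cfg may be omitted now
    # that this patch gives it a default of None.
    asr_model.change_decoding_strategy(decoder_type="ctc")
    print(asr_model.cur_decoder)  # -> "ctc"; transcribe() now routes to the CTC head

    hyps = asr_model.transcribe(paths2audio_files=["audio.wav"], batch_size=1)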