Streaming conformer CTC export (#5837)
* cache-aware streaming export

Test ONNX streaming conformer CTC WER

Constant att cache width with len param

Remove some extra functions in cache_aware runner

Transpose cache so that batch is first for TRT

Signed-off-by: Greg Clark <grclark@nvidia.com>

* fix export for full-context conformer

* WIP trying to improve onnx perf

Signed-off-by: Greg Clark <grclark@nvidia.com>

* Adding test scripts

Signed-off-by: Greg Clark <grclark@nvidia.com>

* More perf testing script

Signed-off-by: Greg Clark <grclark@nvidia.com>

* Updates for jit torch_tensorrt tracing

Signed-off-by: Greg Clark <grclark@nvidia.com>

* Fixed trace warnings

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Rearranging tests

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Fixing non-caching case

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* testing

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Fixed channel cache length issue

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* stash

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Reverting non-essential changes

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Offset=None case

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Remove test scripts

Signed-off-by: Greg Clark <grclark@nvidia.com>

* Clean up speech_to_text_cache_aware_streaming_infer

Signed-off-by: Greg Clark <grclark@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert pad -> constant_pad_nd

Signed-off-by: Greg Clark <grclark@nvidia.com>

* conformer-encoder set window_size from streaming_cfg

Signed-off-by: Greg Clark <grclark@nvidia.com>

* Fixes for working export(), using more constants

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Optional rand init for cache

Signed-off-by: Greg Clark <grclark@nvidia.com>

* Folding update_cache with constants

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* More folding

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Reducing diff #1

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Reducing diff #2

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Reducing diff #3

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Fixed unit tests, more reverts

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Export fixes

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Reverted slice changes that ruined ONNX perf

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>

* Adding back keep_all_outputs and drop_extra_preencoded

Signed-off-by: Greg Clark <grclark@nvidia.com>

* Fix export

Signed-off-by: Greg Clark <grclark@nvidia.com>

---------

Signed-off-by: Greg Clark <grclark@nvidia.com>
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
Co-authored-by: Boris Fomitchev <bfomitchev@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Vahid Noroozi <VahidooX@users.noreply.github.com>
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
4 people authored and hsiehjackson committed Jun 2, 2023
1 parent 8448ee3 commit d1f1c9e
Showing 12 changed files with 284 additions and 127 deletions.
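
For context, below is a minimal sketch (not part of this commit) of driving the exported cache-aware Conformer-CTC encoder with onnxruntime. The tensor names follow the forward_for_export() arguments added in this PR, but the exact exported names, cache shapes, and dtypes are assumptions for illustration; per the commit message, the caches are batch-first.

# Hedged sketch: run one chunk through the exported cache-aware Conformer-CTC ONNX model.
# Tensor names, cache shapes and dtypes below are assumptions for illustration only.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("conformer_ctc_streaming.onnx")  # hypothetical export path

# Dummy preprocessed feature chunk: [batch, feat_dim, chunk_frames] (assumed layout).
chunk = np.random.randn(1, 80, 41).astype(np.float32)
chunk_len = np.array([41], dtype=np.int64)

# Initial caches; in NeMo these can be taken from
# asr_model.encoder.get_initial_cache_state(batch_size=1) and converted to numpy.
# Shapes are assumed (batch-first, per the "transpose cache" change in this PR).
cache_last_channel = np.zeros((1, 17, 70, 512), dtype=np.float32)
cache_last_time = np.zeros((1, 17, 512, 30), dtype=np.float32)
cache_last_channel_len = np.zeros((1,), dtype=np.int64)

outputs = sess.run(
    None,
    {
        "audio_signal": chunk,
        "length": chunk_len,
        "cache_last_channel": cache_last_channel,
        "cache_last_time": cache_last_time,
        "cache_last_channel_len": cache_last_channel_len,
    },
)
# Per the new output_names property and forward_for_export return value, the outputs are
# the decoder log-probs followed by the encoder outputs: lengths and the updated caches.
logprobs, encoded_len, cache_last_channel, cache_last_time, cache_last_channel_len = outputs
# In a real streaming loop, the updated caches are fed back in with the next chunk.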
examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
@@ -100,15 +100,17 @@ def extract_transcriptions(hyps):
     return transcriptions
 
 
-def calc_drop_extra_pre_encoded(asr_model, step_num):
+def calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded):
     # for the first step there is no need to drop any tokens after the downsampling as no caching is being used
-    if step_num == 0:
+    if step_num == 0 and not pad_and_drop_preencoded:
         return 0
     else:
         return asr_model.encoder.streaming_cfg.drop_extra_pre_encoded
 
 
-def perform_streaming(asr_model, streaming_buffer, compare_vs_offline=False, debug_mode=False):
+def perform_streaming(
+    asr_model, streaming_buffer, compare_vs_offline=False, debug_mode=False, pad_and_drop_preencoded=False
+):
     batch_size = len(streaming_buffer.streams_length)
     if compare_vs_offline:
         # would pass the whole audio at once through the model like offline mode in order to compare the results with the streaming mode
@@ -133,7 +135,9 @@ def perform_streaming(asr_model, streaming_buffer, compare_vs_offline=False, deb
     else:
         final_offline_tran = None
 
-    cache_last_channel, cache_last_time = asr_model.encoder.get_initial_cache_state(batch_size=batch_size)
+    cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(
+        batch_size=batch_size
+    )
 
     previous_hypotheses = None
     streaming_buffer_iter = iter(streaming_buffer)
@@ -150,16 +154,20 @@ def perform_streaming(asr_model, streaming_buffer, compare_vs_offline=False, deb
             transcribed_texts,
             cache_last_channel,
             cache_last_time,
+            cache_last_channel_len,
             previous_hypotheses,
         ) = asr_model.conformer_stream_step(
             processed_signal=chunk_audio,
             processed_signal_length=chunk_lengths,
             cache_last_channel=cache_last_channel,
             cache_last_time=cache_last_time,
+            cache_last_channel_len=cache_last_channel_len,
             keep_all_outputs=streaming_buffer.is_buffer_empty(),
             previous_hypotheses=previous_hypotheses,
             previous_pred_out=pred_out_stream,
-            drop_extra_pre_encoded=calc_drop_extra_pre_encoded(asr_model, step_num),
+            drop_extra_pre_encoded=calc_drop_extra_pre_encoded(
+                asr_model, step_num, pad_and_drop_preencoded
+            ),
             return_transcription=True,
         )
 
@@ -243,6 +251,9 @@ def main():
     parser.add_argument(
         "--output_path", type=str, help="path to output file when manifest is used as input", default=None
     )
+    parser.add_argument(
+        "--pad-and-drop-preencoded", action="store_true", help="pad first audio chunk and always drop preencoded"
+    )
 
     args = parser.parse_args()
     if (args.audio_file is None and args.manifest_file is None) or (
@@ -312,14 +323,21 @@ def autocast():
     else:
         online_normalization = False
 
-    streaming_buffer = CacheAwareStreamingAudioBuffer(model=asr_model, online_normalization=online_normalization)
+    streaming_buffer = CacheAwareStreamingAudioBuffer(
+        model=asr_model,
+        online_normalization=online_normalization,
+        pad_and_drop_preencoded=args.pad_and_drop_preencoded,
+    )
     if args.audio_file is not None:
         # stream a single audio file
         processed_signal, processed_signal_length, stream_id = streaming_buffer.append_audio_file(
             args.audio_file, stream_id=-1
         )
         perform_streaming(
-            asr_model=asr_model, streaming_buffer=streaming_buffer, compare_vs_offline=args.compare_vs_offline
+            asr_model=asr_model,
+            streaming_buffer=streaming_buffer,
+            compare_vs_offline=args.compare_vs_offline,
+            pad_and_drop_preencoded=args.pad_and_drop_preencoded,
         )
     else:
         # stream audio files in a manifest file in batched mode
@@ -351,6 +369,7 @@ def autocast():
             streaming_buffer=streaming_buffer,
             compare_vs_offline=args.compare_vs_offline,
             debug_mode=args.debug_mode,
+            pad_and_drop_preencoded=args.pad_and_drop_preencoded,
         )
         all_streaming_tran.extend(streaming_tran)
         if args.compare_vs_offline:
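
Putting the pieces of this file together, a condensed Python-side streaming loop with the new three-tensor cache state looks roughly like the sketch below. The checkpoint and audio paths are placeholders, the unpacking order follows the script, and the drop_extra_pre_encoded logic is inlined from calc_drop_extra_pre_encoded() above.

# Condensed sketch of the updated cache-aware streaming loop (paths are hypothetical).
import torch
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer

asr_model = ASRModel.restore_from("cache_aware_conformer_ctc.nemo")  # placeholder checkpoint
asr_model.eval()

streaming_buffer = CacheAwareStreamingAudioBuffer(model=asr_model, online_normalization=False)
streaming_buffer.append_audio_file("sample.wav", stream_id=-1)  # placeholder audio file

# get_initial_cache_state() now returns three tensors, including the cache length.
cache_last_channel, cache_last_time, cache_last_channel_len = asr_model.encoder.get_initial_cache_state(
    batch_size=1
)
previous_hypotheses = pred_out_stream = None

with torch.inference_mode():
    for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer):
        (
            pred_out_stream,
            transcribed_texts,
            cache_last_channel,
            cache_last_time,
            cache_last_channel_len,
            previous_hypotheses,
        ) = asr_model.conformer_stream_step(
            processed_signal=chunk_audio,
            processed_signal_length=chunk_lengths,
            cache_last_channel=cache_last_channel,
            cache_last_time=cache_last_time,
            cache_last_channel_len=cache_last_channel_len,
            keep_all_outputs=streaming_buffer.is_buffer_empty(),
            previous_hypotheses=previous_hypotheses,
            previous_pred_out=pred_out_stream,
            # same logic as calc_drop_extra_pre_encoded() above, without pad_and_drop_preencoded
            drop_extra_pre_encoded=(
                0 if step_num == 0 else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded
            ),
            return_transcription=True,
        )

print(transcribed_texts)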
54 changes: 43 additions & 11 deletions nemo/collections/asr/models/asr_model.py
@@ -21,6 +21,7 @@
 from nemo.core.classes.common import PretrainedModelInfo
 from nemo.core.classes.exportable import Exportable
 from nemo.core.classes.mixins import AccessMixin
+from nemo.core.utils.neural_type_utils import get_io_names
 from nemo.utils import logging, model_utils
 from nemo.utils.cast_utils import cast_all
 
@@ -156,7 +157,19 @@ def input_module(self):
     def output_module(self):
         return self.decoder
 
-    def forward_for_export(self, input, length=None, cache_last_channel=None, cache_last_time=None):
+    @property
+    def output_names(self):
+        otypes = self.output_module.output_types
+        if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
+            in_types = self.input_module.output_types
+            otypes = {n: t for (n, t) in list(otypes.items())[:1]}
+            for (n, t) in list(in_types.items())[1:]:
+                otypes[n] = t
+        return get_io_names(otypes, self.disabled_deployment_output_names)
+
+    def forward_for_export(
+        self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
+    ):
         """
         This forward is used when we need to export the model to ONNX format.
         Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models.
@@ -175,34 +188,53 @@ def forward_for_export(self, input, length=None, cache_last_channel=None, cache_
         """
         if hasattr(self.input_module, 'forward_for_export'):
             if cache_last_channel is None and cache_last_time is None:
-                encoder_output = self.input_module.forward_for_export(input, length)
+                encoder_output = self.input_module.forward_for_export(audio_signal=input, length=length)
             else:
                 encoder_output = self.input_module.forward_for_export(
-                    input, length, cache_last_channel, cache_last_time
+                    audio_signal=input,
+                    length=length,
+                    cache_last_channel=cache_last_channel,
+                    cache_last_time=cache_last_time,
+                    cache_last_channel_len=cache_last_channel_len,
                 )
         else:
             if cache_last_channel is None and cache_last_time is None:
-                encoder_output = self.input_module(input, length)
+                encoder_output = self.input_module(audio_signal=input, length=length)
             else:
-                encoder_output = self.input_module(input, length, cache_last_channel, cache_last_time)
+                encoder_output = self.input_module(
+                    audio_signal=input,
+                    length=length,
+                    cache_last_channel=cache_last_channel,
+                    cache_last_time=cache_last_time,
+                    cache_last_channel_len=cache_last_channel_len,
+                )
         if isinstance(encoder_output, tuple):
             decoder_input = encoder_output[0]
         else:
             decoder_input = encoder_output
         if hasattr(self.output_module, 'forward_for_export'):
             if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module.forward_for_export(decoder_input)
+                ret = self.output_module.forward_for_export(encoder_output=decoder_input)
             else:
-                # TODO: update this part to support full encoder/decoder export
-                ret = encoder_output
+                ret = self.output_module.forward_for_export(encoder_output=decoder_input)
         else:
             if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module(decoder_input)
+                ret = self.output_module(encoder_output=decoder_input)
             else:
-                # TODO: update this part to support full encoder/decoder export
-                ret = encoder_output
+                ret = self.output_module(encoder_output=decoder_input)
+        if cache_last_channel is None and cache_last_time is None:
+            pass
+        else:
+            if isinstance(ret, tuple):
+                ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
+            else:
+                ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
         return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)
 
     @property
     def disabled_deployment_input_names(self):
         return self.encoder.disabled_deployment_input_names
+
+    @property
+    def disabled_deployment_output_names(self):
+        return self.encoder.disabled_deployment_output_names
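
A hedged sketch of how the new cache-aware export path might be exercised end to end is shown below. export_cache_support is the attribute checked in the diff above; exactly how it is enabled can differ between NeMo versions, and the checkpoint and output paths are placeholders.

# Hedged sketch: export a cache-aware streaming Conformer-CTC model to ONNX.
from nemo.collections.asr.models import EncDecCTCModelBPE

asr_model = EncDecCTCModelBPE.restore_from("cache_aware_conformer_ctc.nemo")  # placeholder checkpoint
asr_model.eval()

# The encoder advertises streaming export through export_cache_support (see the hasattr check above);
# setting it directly like this is an assumption about the encoder's API, not something this diff shows.
asr_model.encoder.export_cache_support = True

# With cache support on, output_names becomes the decoder log-probs followed by the encoder's
# cache outputs, matching the 5-tuple returned by forward_for_export.
asr_model.export("conformer_ctc_streaming.onnx")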