Test onnx streaming conformer ctc WER
messiaen committed Jan 23, 2023
1 parent d993e83 commit accc444
Showing 7 changed files with 197 additions and 54 deletions.
@@ -76,6 +76,7 @@
import time
from argparse import ArgumentParser

import onnxruntime
import torch
from omegaconf import open_dict

@@ -86,6 +87,10 @@
from nemo.utils import logging


def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def extract_transcriptions(hyps):
"""
The transcribed_texts returned by CTC and RNNT models are different.
@@ -102,12 +107,134 @@ def extract_transcriptions(hyps):

def calc_drop_extra_pre_encoded(asr_model, step_num):
# for the first step there is no need to drop any tokens after the downsampling as no caching is being used
# return asr_model.encoder.streaming_cfg.drop_extra_pre_encoded
if step_num == 0:
return 0
else:
return asr_model.encoder.streaming_cfg.drop_extra_pre_encoded
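
The value returned here is the number of leading pre-encoded frames to discard once caching has started. Applying it looks roughly like the slicing done inside forward_internal later in this diff; a minimal sketch (pre_encoded and pre_encoded_len are illustrative names, not identifiers from this commit):

    drop = calc_drop_extra_pre_encoded(asr_model, step_num)
    if drop > 0:
        # frames [0:drop] overlap with context already held in the cache
        pre_encoded = pre_encoded[:, drop:, :]
        pre_encoded_len = torch.clamp(pre_encoded_len - drop, min=0)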


def perform_streaming_unified_onnx(asr_model, sess, streaming_buffer, compare_vs_offline=False, debug_mode=False):
asr_model.encoder.export_streaming_support = True
batch_size = len(streaming_buffer.streams_length)
final_offline_tran = None

# TODO compute without model
cache_last_channel, cache_last_time = asr_model.encoder.get_initial_cache_state(batch_size=batch_size)

previous_hypotheses = None
streaming_buffer_iter = iter(streaming_buffer)
pred_out_stream = None
all_preds = None
all_preds_lens = None
transcribed_texts = []

# drop_extra_pre_encoded = torch.full((batch_size,), 0, dtype=torch.int32, device=asr_model.device)
for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
ort_inputs = {}
ort_inputs[sess.get_inputs()[0].name] = to_numpy(chunk_audio)
ort_inputs[sess.get_inputs()[1].name] = to_numpy(chunk_lengths)
ort_inputs[sess.get_inputs()[2].name] = to_numpy(cache_last_channel)
ort_inputs[sess.get_inputs()[3].name] = to_numpy(cache_last_time)

ort_outputs = sess.run(None, ort_inputs)
log_probs = torch.from_numpy(ort_outputs[0])
encoded_len = torch.from_numpy(ort_outputs[1])
cache_last_channel = torch.from_numpy(ort_outputs[2])
cache_last_time = torch.from_numpy(ort_outputs[3])

predictions_tensor = log_probs.argmax(dim=-1, keepdim=False)
if step_num == 0:
all_preds = predictions_tensor
all_preds_lens = encoded_len
else:
all_preds = torch.cat((all_preds, predictions_tensor), dim=1)
all_preds_lens = all_preds_lens + encoded_len

for preds_idx, preds in enumerate(all_preds):
preds_cur = all_preds[preds_idx, : all_preds_lens[preds_idx]]

# TODO: make decoding more efficient by avoiding the decoding process from the beginning
decoded_out = asr_model.decoding.ctc_decoder_predictions_tensor(
decoder_outputs=preds_cur.unsqueeze(0),
decoder_lengths=all_preds_lens[preds_idx : preds_idx + 1],
return_hypotheses=False,
)
# print("DEBUG decoded_out", decoded_out)
transcribed_texts.append(decoded_out[0][0])

# if debug_mode:
# logging.info(f"Streaming transcriptions: {extract_transcriptions(transcribed_texts)}")

final_streaming_tran = extract_transcriptions(transcribed_texts)
logging.info(f"Final streaming transcriptions: {final_streaming_tran}")

return final_streaming_tran, final_offline_tran
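
Note that ort_inputs above is filled by position (sess.get_inputs()[0..3]). Based on the input_types/output_types declared in conformer_encoder.py and the new output_names property in asr_model.py further down this diff, binding by name should also work, roughly as in the sketch below (the exact exported names are an assumption, not verified against a real export):

    ort_inputs = {
        "audio_signal": to_numpy(chunk_audio),
        "length": to_numpy(chunk_lengths),
        "cache_last_channel": to_numpy(cache_last_channel),
        "cache_last_time": to_numpy(cache_last_time),
    }
    # expected outputs, in order: logprobs, encoded_lengths, cache_last_channel_next, cache_last_time_next
    log_probs, encoded_len, cache_last_channel, cache_last_time = (
        torch.from_numpy(o) for o in sess.run(None, ort_inputs)
    )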


def perform_streaming_unified(asr_model, streaming_buffer, compare_vs_offline=False, debug_mode=False):
batch_size = len(streaming_buffer.streams_length)
final_offline_tran = None
asr_model.encoder.export_streaming_support = True

cache_last_channel, cache_last_time = asr_model.encoder.get_initial_cache_state(batch_size=batch_size)

previous_hypotheses = None
streaming_buffer_iter = iter(streaming_buffer)
pred_out_stream = None
all_preds = None
all_preds_lens = None
transcribed_texts = []

# drop_extra_pre_encoded = torch.full((batch_size,), 0, dtype=torch.int32, device=asr_model.device)
for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
# print("DEBUG audio", chunk_audio.size())
with torch.inference_mode():
with autocast():
# keep_all_outputs needs to be True for the last step of streaming when model is trained with att_context_style=regular
# otherwise the last outputs would get dropped

with torch.no_grad():
(log_probs, encoded_len, cache_last_channel, cache_last_time,) = asr_model.forward_for_export(
input=chunk_audio,
length=chunk_lengths,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
)
# print("DEBUG log_probs.size()", log_probs.size())
predictions_tensor = log_probs.argmax(dim=-1, keepdim=False)
# print("DEBUG preds", predictions_tensor)
# print("DEBUG encoded_len", encoded_len)
if step_num == 0:
all_preds = predictions_tensor
all_preds_lens = encoded_len
else:
all_preds = torch.cat((all_preds, predictions_tensor), dim=1)
all_preds_lens = all_preds_lens + encoded_len

# print("DEBUG all_pred", all_preds)
# print("DEBUG all_pred", all_preds.size())
for preds_idx, preds in enumerate(all_preds):
preds_cur = all_preds[preds_idx, : all_preds_lens[preds_idx]]

# TODO: make decoding more efficient by avoiding the decoding process from the beginning
decoded_out = asr_model.decoding.ctc_decoder_predictions_tensor(
decoder_outputs=preds_cur.unsqueeze(0),
decoder_lengths=all_preds_lens[preds_idx : preds_idx + 1],
return_hypotheses=False,
)
# print("DEBUG decoded_out", decoded_out)
transcribed_texts.append(decoded_out[0][0])

# if debug_mode:
# logging.info(f"Streaming transcriptions: {extract_transcriptions(transcribed_texts)}")

final_streaming_tran = extract_transcriptions(transcribed_texts)
logging.info(f"Final streaming transcriptions: {final_streaming_tran}")

return final_streaming_tran, final_offline_tran


def perform_streaming(asr_model, streaming_buffer, compare_vs_offline=False, debug_mode=False):
batch_size = len(streaming_buffer.streams_length)
if compare_vs_offline:
@@ -287,6 +414,13 @@ def autocast():
asr_model = asr_model.to(args.device)
asr_model.eval()

sess = onnxruntime.InferenceSession("tmp.onnx", providers=['CUDAExecutionProvider'])

for input in sess.get_inputs():
print("ort input", input.name)
for out in sess.get_outputs():
print("ort output", out.name)

# chunk_size is set automatically for models trained for streaming. For models trained for offline mode with full context, we need to pass the chunk_size explicitly.
if args.chunk_size > 0:
if args.shift_size < 0:
@@ -346,8 +480,11 @@ def autocast():

if (sample_idx + 1) % args.batch_size == 0 or sample_idx == len(samples) - 1:
logging.info(f"Starting to stream samples {sample_idx - len(streaming_buffer) + 1} to {sample_idx}...")
streaming_tran, offline_tran = perform_streaming(
# streaming_tran, offline_tran = perform_streaming(
# streaming_tran, offline_tran = perform_streaming_unified(
streaming_tran, offline_tran = perform_streaming_unified_onnx(
asr_model=asr_model,
sess=sess,
streaming_buffer=streaming_buffer,
compare_vs_offline=args.compare_vs_offline,
debug_mode=args.debug_mode,
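The rest of this script is cut off above; per the commit title, it presumably goes on to compare the streaming transcripts against references. A hedged sketch of such a WER check, using NeMo's word_error_rate helper (its use here is an assumption; all_streaming_trans and all_refs are illustrative names):

    from nemo.collections.asr.metrics.wer import word_error_rate

    wer = word_error_rate(hypotheses=all_streaming_trans, references=all_refs)
    logging.info(f"Streaming ONNX WER: {wer:.4f}")
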
37 changes: 25 additions & 12 deletions nemo/collections/asr/models/asr_model.py
@@ -21,6 +21,7 @@
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import AccessMixin
from nemo.core.utils.neural_type_utils import get_io_names
from nemo.utils import logging, model_utils
from nemo.utils.cast_utils import cast_all

@@ -155,7 +156,17 @@ def input_module(self):
def output_module(self):
return self.decoder

def forward_for_export(self, input, length=None, cache_last_channel=None, cache_last_time=None, drop_extra_pre_encoded=None):
@property
def output_names(self):
if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
out_types = self.output_module.output_types
in_types = self.input_module.output_types
otypes = {n: t for (n, t) in list(out_types.items())[:1]}
for (n, t) in list(in_types.items())[1:]:
otypes[n] = t
return get_io_names(otypes, self.disabled_deployment_output_names)

def forward_for_export(self, input, length=None, cache_last_channel=None, cache_last_time=None):
"""
This forward is used when we need to export the model to ONNX format.
Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models.
@@ -177,13 +188,16 @@ def forward_for_export(self, input, length=None, cache_last_channel=None, cache_
encoder_output = self.input_module.forward_for_export(input, length)
else:
encoder_output = self.input_module.forward_for_export(
input, length, cache_last_channel, cache_last_time, drop_extra_pre_encoded
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
)
else:
if cache_last_channel is None and cache_last_time is None:
encoder_output = self.input_module(input, length)
else:
encoder_output = self.input_module(input, length, cache_last_channel, cache_last_time, drop_extra_pre_encoded)
encoder_output = self.input_module(input, length, cache_last_channel, cache_last_time)
if isinstance(encoder_output, tuple):
decoder_input = encoder_output[0]
else:
@@ -192,20 +206,19 @@ def forward_for_export(self, input, length=None, cache_last_channel=None, cache_
if cache_last_channel is None and cache_last_time is None:
ret = self.output_module.forward_for_export(decoder_input)
else:
# TODO: update this part to support full encoder/decoder export
#ret = encoder_output
ret = self.output_module.forward_for_export(decoder_input)
else:
if cache_last_channel is None and cache_last_time is None:
ret = self.output_module(decoder_input)
else:
# TODO: update this part to support full encoder/decoder export
#ret = encoder_output
ret = self.output_module(decoder_input)
if isinstance(ret, tuple):
print("output tuple len", len(ret))
ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
print("output tuple len 2", len(ret))
ret = self.output_module(encoder_output=decoder_input)
if cache_last_channel is None and cache_last_time is None:
pass
else:
if isinstance(ret, tuple):
ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3])
else:
ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3])
return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)

@property
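Taken together, the new output_names property and the repacked return tuple mean the exported cache-aware model should expose one decoder output followed by the encoder's cache-related outputs. A small sanity-check sketch against the ONNX session (the names are assumptions derived from the NeuralType dicts in this diff, not read from a real export):

    expected = ["logprobs", "encoded_lengths", "cache_last_channel_next", "cache_last_time_next"]
    actual = [o.name for o in sess.get_outputs()]
    assert actual == expected, f"unexpected ONNX outputs: {actual}"
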
60 changes: 23 additions & 37 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -120,9 +120,7 @@ def input_example(self, max_batch=1, max_dim=256):
if hasattr(self, 'export_cache_support') and self.export_cache_support:
cache_last_channel = torch.randn(self.n_layers, max_batch, max_dim, self.d_model).to(dev)
cache_last_time = torch.randn(self.n_layers, max_batch, self.d_model, self.conv_context_size[0]).to(dev)
drop_extra_pre_encoded = torch.randint(1, max_dim, (max_batch,)).to(dev)
drop_extra_pre_encoded = torch.clamp(drop_extra_pre_encoded - input_example_length, min=0)
all_input_example = tuple([input_example, input_example_length, cache_last_channel, cache_last_time, drop_extra_pre_encoded])
all_input_example = tuple([input_example, input_example_length, cache_last_channel, cache_last_time])
else:
all_input_example = tuple([input_example, input_example_length])

@@ -137,7 +135,6 @@ def input_types(self):
"length": NeuralType(tuple('B'), LengthsType()),
"cache_last_channel": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True),
"cache_last_time": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True),
"drop_extra_pre_encoded": NeuralType(tuple('B'), LengthsType(), optional=True),
}
)

@@ -150,7 +147,6 @@ def output_types(self):
"encoded_lengths": NeuralType(tuple('B'), LengthsType()),
"cache_last_channel_next": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True),
"cache_last_time_next": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True),
"drop_extra_pre_encoded_next": NeuralType(tuple('B'), LengthsType(), optional=True),
}
)

@@ -365,6 +361,7 @@ def __init__(

self.setup_streaming_params()
self.export_cache_support = False
self.export_streaming_support = False

def update_max_seq_length(self, seq_length: int, device):
# Find global max audio length across all nodes
@@ -412,7 +409,14 @@ def set_max_audio_length(self, max_audio_length):
self.att_mask = None

@typecheck()
def forward_for_export(self, audio_signal=None, length=None, cache_last_channel=None, cache_last_time=None, drop_extra_pre_encoded=None):
def forward_for_export(self, audio_signal=None, length=None, cache_last_channel=None, cache_last_time=None):
if not self.export_streaming_support:
return self.forward_internal(
audio_signal=audio_signal,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
)
if self.streaming_cfg is None:
self.setup_streaming_params()

@@ -421,14 +425,13 @@ def forward_for_export(self, audio_signal=None, length=None, cache_last_channel=
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
drop_extra_pre_encoded=drop_extra_pre_encoded,
)

if len(encoder_output) == 2:
encoded, encoded_len = encoder_output
cache_last_channel_next = cache_last_time_next = None
else:
encoded, encoded_len, cache_last_channel_next, cache_last_time_next, drop_extra_pre_encoded_next = encoder_output
encoded, encoded_len, cache_last_channel_next, cache_last_time_next = encoder_output

if cache_last_channel_next is not None and self.streaming_cfg.last_channel_cache_size >= 0:
if self.streaming_cfg.last_channel_cache_size > 0:
@@ -437,24 +440,23 @@ def forward_for_export(self, audio_signal=None, length=None, cache_last_channel=
]
else:
cache_last_channel_next = cache_last_channel_next[:, :, 0:0, :]
if True:
encoded = encoded[:, :, : self.streaming_cfg.valid_out_len]
encoded_len = torch.clamp(encoded_len, max=self.streaming_cfg.valid_out_len)

return encoded, encoded_len, cache_last_channel_next, cache_last_time_next, drop_extra_pre_encoded_next
encoded = encoded[:, :, : self.streaming_cfg.valid_out_len]
encoded_len = torch.clamp(encoded_len.long(), max=self.streaming_cfg.valid_out_len).int()

return encoded, encoded_len, cache_last_channel_next, cache_last_time_next

@typecheck()
def forward(self, audio_signal, length, cache_last_channel=None, cache_last_time=None, drop_extra_pre_encoded=None):
def forward(self, audio_signal, length, cache_last_channel=None, cache_last_time=None):
return self.forward_internal(
audio_signal=audio_signal,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
drop_extra_pre_encoded=drop_extra_pre_encoded,
)

@typecheck()
def forward_internal(self, audio_signal, length, cache_last_channel=None, cache_last_time=None, drop_extra_pre_encoded=None):
def forward_internal(self, audio_signal, length, cache_last_channel=None, cache_last_time=None):
self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
max_audio_length: int = audio_signal.size(-1)

@@ -474,24 +476,11 @@ def forward_internal(self, audio_signal, length, cache_last_channel=None, cache_
else:
audio_signal, length = self.pre_encode(x=audio_signal, lengths=length)
# self.streaming_cfg is set by setup_streaming_cfg(), called in the init
if drop_extra_pre_encoded is None:
if self.streaming_cfg.drop_extra_pre_encoded > 0 and cache_last_channel is not None:
audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :]
# TODO: find a better solution
length = (length - self.streaming_cfg.drop_extra_pre_encoded)
length = torch.clamp(length, min=torch.tensor(0, dtype=torch.int32, device=length.device, requires_grad=False))
else:
# this is only for inference
audio_signal = audio_signal.transpose(1, 2)
gather_idx = ((torch.meshgrid(
torch.arange(audio_signal.size(0)),
torch.arange(audio_signal.size(1)),
torch.arange(audio_signal.size(2)))[2].transpose(0,2).to(audio_signal.device) - drop_extra_pre_encoded) % audio_signal.size(2)).transpose(0,2)
audio_signal = torch.gather(audio_signal, 2, gather_idx).transpose(1,2)
length = (length - drop_extra_pre_encoded)
print("DEBUG length = ", length)
length = torch.clamp(length, min=torch.tensor(0, dtype=torch.int32, device=length.device, requires_grad=False))

if self.streaming_cfg.drop_extra_pre_encoded > 0 and cache_last_channel is not None:
audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :]
# TODO: find a better solution
length = length - self.streaming_cfg.drop_extra_pre_encoded
length = torch.clamp(length.long(), min=0).int()

max_audio_length = audio_signal.size(1)

@@ -556,10 +545,7 @@ def forward_internal(self, audio_signal, length, cache_last_channel=None, cache_
audio_signal = torch.transpose(audio_signal, 1, 2)

if cache_last_channel is not None:
if drop_extra_pre_encoded is not None:
return audio_signal, length, cache_last_channel_next, cache_last_time_next, torch.full(length.size(), self.streaming_cfg.drop_extra_pre_encoded, dtype=length.dtype, requires_grad=False)
else:
return audio_signal, length, cache_last_channel_next, cache_last_time_next
return audio_signal, length, cache_last_channel_next, cache_last_time_next
else:
return audio_signal, length

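With the drop_extra_pre_encoded input gone, the per-step trimming now happens entirely inside forward_internal and only when a cache is passed. A rough numeric illustration of the simplified behaviour (the drop value of 2, 18 pre-encoded frames, and d_model of 512 are assumptions, not values from this diff):

    import torch

    # assuming streaming_cfg.drop_extra_pre_encoded == 2 and a chunk that pre-encodes to 18 frames
    audio_signal = torch.randn(1, 18, 512)          # (batch, time, d_model) after pre_encode
    length = torch.tensor([18], dtype=torch.int32)
    drop = 2                                        # streaming_cfg.drop_extra_pre_encoded
    audio_signal = audio_signal[:, drop:, :]        # -> 16 frames; step 0 skips this (no cache)
    length = torch.clamp((length - drop).long(), min=0).int()
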
(Diffs for the remaining 4 of the 7 changed files are not shown.)
