Skip to content

Commit

Permalink
More streaming conformer export fixes (NVIDIA#6567) (NVIDIA#6578)
Browse files Browse the repository at this point in the history
Signed-off-by: Greg Clark <[email protected]>
Co-authored-by: Greg Clark <[email protected]>
Co-authored-by: Vahid Noroozi <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>
  • Loading branch information
3 people authored and hsiehjackson committed Jun 2, 2023
1 parent 04c1b72 commit c13ffb9
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
29 changes: 29 additions & 0 deletions nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,19 @@ def input_types(self):
}
)

@property
def input_types_for_export(self):
    """Describe the input ports this module declares for export.

    Returns:
        OrderedDict: maps each port name to its ``NeuralType``, in the order
        ``audio_signal``, ``length``, ``cache_last_channel``,
        ``cache_last_time``, ``cache_last_channel_len``.
    """
    export_inputs = OrderedDict()
    export_inputs["audio_signal"] = NeuralType(('B', 'D', 'T'), SpectrogramType())
    export_inputs["length"] = NeuralType(tuple('B'), LengthsType())
    # The three cache ports are optional; the cache tensors are declared
    # with the batch axis first.
    export_inputs["cache_last_channel"] = NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True)
    export_inputs["cache_last_time"] = NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True)
    export_inputs["cache_last_channel_len"] = NeuralType(tuple('B'), LengthsType(), optional=True)
    return export_inputs

@property
def output_types(self):
"""Returns definitions of module output ports."""
Expand All @@ -196,6 +209,19 @@ def output_types(self):
}
)

@property
def output_types_for_export(self):
    """Describe the output ports this module declares for export.

    Returns:
        OrderedDict: maps each port name to its ``NeuralType``, in the order
        ``outputs``, ``encoded_lengths``, ``cache_last_channel_next``,
        ``cache_last_time_next``, ``cache_last_channel_next_len``.
    """
    export_outputs = OrderedDict()
    export_outputs["outputs"] = NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())
    export_outputs["encoded_lengths"] = NeuralType(tuple('B'), LengthsType())
    # The three cache ports are optional; the cache tensors are declared
    # with the batch axis first.
    export_outputs["cache_last_channel_next"] = NeuralType(('B', 'D', 'T', 'D'), ChannelType(), optional=True)
    export_outputs["cache_last_time_next"] = NeuralType(('B', 'D', 'D', 'T'), ChannelType(), optional=True)
    export_outputs["cache_last_channel_next_len"] = NeuralType(tuple('B'), LengthsType(), optional=True)
    return export_outputs

@property
def disabled_deployment_input_names(self):
if not self.export_cache_support:
Expand Down Expand Up @@ -489,6 +515,8 @@ def forward_for_export(
rets = self.streaming_post_process(rets, keep_all_outputs=False)
if len(rets) == 2:
return rets
elif rets[2] is None and rets[3] is None and rets[4] is None:
return (rets[0], rets[1])
else:
return (
rets[0],
Expand Down Expand Up @@ -549,6 +577,7 @@ def forward_internal(
audio_signal = self.pre_encode(audio_signal)
else:
audio_signal, length = self.pre_encode(x=audio_signal, lengths=length)
length = length.to(torch.int64)
# self.streaming_cfg is set by setup_streaming_cfg(), called in the init
if self.streaming_cfg.drop_extra_pre_encoded > 0 and cache_last_channel is not None:
audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :]
Expand Down
16 changes: 12 additions & 4 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def _export(
elif format == ExportFormat.ONNX:
# dynamic axis is a mapping from input/output_name => list of "dynamic" indices
if dynamic_axes is None:
dynamic_axes = get_dynamic_axes(self.input_module.input_types, input_names)
dynamic_axes.update(get_dynamic_axes(self.output_module.output_types, output_names))
dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names)
dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names))
torch.onnx.export(
jitted_model,
input_example,
Expand Down Expand Up @@ -273,11 +273,19 @@ def _export_teardown(self):

@property
def input_names(self):
return get_io_names(self.input_module.input_types, self.disabled_deployment_input_names)
return get_io_names(self.input_module.input_types_for_export, self.disabled_deployment_input_names)

@property
def output_names(self):
return get_io_names(self.output_module.output_types, self.disabled_deployment_output_names)
return get_io_names(self.output_module.output_types_for_export, self.disabled_deployment_output_names)

@property
def input_types_for_export(self):
    """Input port definitions to use when exporting.

    This implementation forwards ``self.input_types`` unchanged. The export
    code in this file reads the property from the input module to derive
    the input names and the ONNX dynamic axes.
    """
    export_inputs = self.input_types
    return export_inputs

@property
def output_types_for_export(self):
    """Output port definitions to use when exporting.

    This implementation forwards ``self.output_types`` unchanged. The export
    code in this file reads the property from the output module to derive
    the output names and the ONNX dynamic axes.
    """
    export_outputs = self.output_types
    return export_outputs

def get_export_subnet(self, subnet=None):
"""
Expand Down
1 change: 1 addition & 0 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import torch
from pytorch_lightning import Trainer

import nemo
from nemo.core import ModelPT
from nemo.core.classes import Exportable
from nemo.core.config.pytorch_lightning import TrainerConfig
Expand Down

0 comments on commit c13ffb9

Please sign in to comment.