Merged
14 changes: 14 additions & 0 deletions optimum/exporters/onnx/base.py
@@ -120,6 +120,7 @@ class OnnxConfig(ExportConfig, ABC):
"seq2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
"sequence-classification": OrderedDict({"logits": {0: "batch_size"}}),
"token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
Member
Is it the "official" name?

We could take:

  • automatic-speech-recognition to match the pipelines
  • speech2text

@lewtun wdyt?

Member
I think the idea was to partially align with the underlying autoclass, but I agree automatic-speech-recognition would be more intuitive.

In general (not for this PR), I think we should take the opportunity to align more closely with the Hub tasks, e.g. seq2seq-lm could also be text2text-generation, right?

Member

Alright, then I guess we can keep speech2seq-lm for now, since the other names are aligned with the AutoClass names, and maybe change that (if needed) for all the tasks in another PR.

}

def __init__(
@@ -206,6 +207,17 @@ def is_torch_support_available(self) -> bool:
return TORCH_VERSION >= self.MIN_TORCH_VERSION
return False

@property
def torch_to_onnx_input_map(self) -> Mapping[str, str]:
Member
I would make it clear that it is needed when the dummy input names and the exported input names do not match.

Contributor Author
@mht-sharma, Nov 15, 2022
Updated the docstring.

"""
Dictionary mapping dummy input names to the input names used in the exported ONNX model.
Override this property when the dummy input names and the exported ONNX input names need to be different.

Returns:
`Mapping[str, str]`: A dictionary mapping each dummy input name to its exported ONNX input name.
"""
return {}

def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int, str]]:
"""
Re-orders the inputs using the model forward pass signature.
@@ -218,6 +230,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
`Mapping[str, Mapping[int, str]]`: The properly ordered inputs.
"""
inputs = self.inputs

ordered_inputs = {}
sig = inspect.signature(model.forward)
for param in sig.parameters:
@@ -229,6 +242,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
# TODO: figure out a smart way of re-ordering potential nested structures.
# to_insert = sorted(to_insert, key=lambda t: t[0])
for name, dynamic_axes in to_insert:
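# Remap the dummy input name to the input name expected by the exported ONNX graph, if a mapping is defined.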
name = self.torch_to_onnx_input_map.get(name, name)
ordered_inputs[name] = dynamic_axes
return ordered_inputs

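A minimal sketch of how the new hook is meant to be used (the class name below is illustrative, not part of this PR): a config whose dummy inputs are named differently from the graph inputs overrides `torch_to_onnx_input_map`, and `ordered_inputs` applies the mapping when building the export inputs.

```python
from typing import Mapping

from optimum.exporters.onnx.config import Seq2SeqOnnxConfig


class MyDecoderOnnxConfig(Seq2SeqOnnxConfig):
    @property
    def torch_to_onnx_input_map(self) -> Mapping[str, str]:
        # Dummy inputs are generated as `decoder_input_ids` / `encoder_outputs`,
        # but the exported decoder graph expects `input_ids` / `encoder_hidden_states`.
        return {
            "decoder_input_ids": "input_ids",
            "encoder_outputs": "encoder_hidden_states",
        }
```

This mirrors the override added for `SpeechSeq2SeqDecoderOnnxConfig` further down in this PR.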
13 changes: 13 additions & 0 deletions optimum/exporters/onnx/config.py
@@ -17,6 +17,7 @@
from typing import Mapping

from ...utils import (
DummyAudioInputGenerator,
DummyBboxInputGenerator,
DummyDecoderTextInputGenerator,
DummyPastKeyValuesGenerator,
@@ -99,3 +100,15 @@ class VisionOnnxConfig(OnnxConfig):

class TextAndVisionOnnxConfig(OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator, DummyBboxInputGenerator)


class AudioOnnxConfig(OnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyAudioInputGenerator,)


class TextAndAudioOnnxConfig(Seq2SeqOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyAudioInputGenerator,
DummyDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
)
1 change: 1 addition & 0 deletions optimum/exporters/onnx/convert.py
@@ -125,6 +125,7 @@ def export_pytorch(
else:
# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.

onnx_export(
model,
(dummy_inputs,),
81 changes: 80 additions & 1 deletion optimum/exporters/onnx/model_configs.py
@@ -21,16 +21,26 @@
from ...utils import (
DummyDecoderTextInputGenerator,
DummyPastKeyValuesGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
DummyVisionInputGenerator,
NormalizedConfig,
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedVisionConfig,
)
from .base import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from .config import DecoderOnnxConfig, EncoderOnnxConfig, Seq2SeqOnnxConfig, TextAndVisionOnnxConfig, VisionOnnxConfig
from .config import (
AudioOnnxConfig,
DecoderOnnxConfig,
EncoderOnnxConfig,
Seq2SeqOnnxConfig,
TextAndAudioOnnxConfig,
TextAndVisionOnnxConfig,
VisionOnnxConfig,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -626,3 +636,72 @@ def generate_dummy_inputs(self, framework: str = "pt"):
self.is_generating_dummy_inputs = True
dummy_inputs[self.inputs_name] = dummy_inputs.pop(specialized_inputs_name)
return dummy_inputs


class WhisperOnnxConfig(TextAndAudioOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = {
"input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
}
if self.use_past:
common_inputs["decoder_input_ids"] = {0: "batch_size"}
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

if self.use_past:
self.add_past_key_values(common_inputs, direction="inputs")

return common_inputs


class SpeechSeq2SeqEncoderOnnxConfig(AudioOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedConfig

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return {
"input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
}


class SpeechSeq2SeqDecoderOnnxConfig(Seq2SeqOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

DUMMY_INPUT_GENERATOR_CLASSES = (
DummySeq2SeqDecoderTextInputGenerator,
DummyDecoderTextInputGenerator,
DummySeq2SeqPastKeyValuesGenerator,
)

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = {
"decoder_input_ids": {0: "batch_size", 1: "past_decoder_sequence_length + sequence_length"},
"encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"},
}

if self.use_past:
self.add_past_key_values(common_inputs, direction="inputs")

return common_inputs

@property
def torch_to_onnx_input_map(self) -> Mapping[str, str]:
return {"decoder_input_ids": "input_ids", "encoder_outputs": "encoder_hidden_states"}

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
common_outputs = super().outputs
self.add_past_key_values(common_outputs, direction="outputs")
return common_outputs

@property
def values_override(self) -> Optional[Mapping[str, Any]]:
if hasattr(self._config, "use_cache"):
return {"use_cache": True}

return None
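As an illustration only (the constructor call below assumes the usual `OnnxConfig(model_config, task=...)` signature and is not part of this diff), the dynamic axes declared by the new Whisper config can be inspected directly:

```python
# Sketch: inspect the inputs declared by the new Whisper ONNX config.
from transformers import AutoConfig

from optimum.exporters.onnx.model_configs import WhisperOnnxConfig

model_config = AutoConfig.from_pretrained("openai/whisper-tiny.en")  # illustrative checkpoint
onnx_config = WhisperOnnxConfig(model_config, task="speech2seq-lm")

print(onnx_config.inputs)
# Without past key values, this should be:
# {'input_features': {0: 'batch_size', 1: 'feature_size', 2: 'encoder_sequence_length'},
#  'decoder_input_ids': {0: 'batch_size', 1: 'decoder_sequence_length'}}
```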
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
@@ -47,6 +47,7 @@
AutoModelForSemanticSegmentation,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTokenClassification,
)
if is_tf_available():
@@ -122,6 +123,7 @@ class TasksManager:
"image-segmentation": AutoModelForImageSegmentation,
"masked-im": AutoModelForMaskedImageModeling,
"semantic-segmentation": AutoModelForSemanticSegmentation,
"speech2seq-lm": AutoModelForSpeechSeq2Seq,
}
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
@@ -506,6 +508,13 @@ class TasksManager:
onnx="T5OnnxConfig",
),
"vit": supported_tasks_mapping("default", "image-classification", "masked-im", onnx="ViTOnnxConfig"),
"whisper": supported_tasks_mapping(
"default",
"default-with-past",
"speech2seq-lm",
"speech2seq-lm-with-past",
onnx="WhisperOnnxConfig",
),
"xlm": supported_tasks_mapping(
"default",
"masked-lm",
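Hedged usage sketch (not part of this diff): with the new entries, `TasksManager` can resolve a Whisper checkpoint for the `speech2seq-lm` task; `get_model_from_task(task, model_name)` is assumed to be the lookup helper in this version of optimum.

```python
from optimum.exporters.tasks import TasksManager

# Resolves AutoModelForSpeechSeq2Seq from the task name and loads the checkpoint with it.
model = TasksManager.get_model_from_task("speech2seq-lm", "openai/whisper-tiny.en")
print(type(model).__name__)  # expected: WhisperForConditionalGeneration
```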
4 changes: 2 additions & 2 deletions optimum/onnxruntime/__init__.py
@@ -37,7 +37,7 @@
"ORTModelForSequenceClassification",
"ORTModelForTokenClassification",
],
"modeling_seq2seq": ["ORTModelForSeq2SeqLM"],
"modeling_seq2seq": ["ORTModelForSeq2SeqLM", "ORTModelForSpeechSeq2Seq"],
"optimization": ["ORTOptimizer"],
"quantization": ["ORTQuantizer"],
"trainer": ["ORTTrainer"],
@@ -68,7 +68,7 @@
ORTModelForSequenceClassification,
ORTModelForTokenClassification,
)
from .modeling_seq2seq import ORTModelForSeq2SeqLM
from .modeling_seq2seq import ORTModelForSeq2SeqLM, ORTModelForSpeechSeq2Seq
from .optimization import ORTOptimizer
from .quantization import ORTQuantizer
from .trainer import ORTTrainer
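A hedged end-to-end sketch of the newly exported class (not part of this diff; `from_transformers=True` is assumed to be the on-the-fly export flag of the ORTModel API at the time of this PR, and the checkpoint and audio are illustrative):

```python
import numpy as np
from transformers import AutoProcessor

from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

model_id = "openai/whisper-tiny.en"
processor = AutoProcessor.from_pretrained(model_id)

# Export the PyTorch checkpoint to ONNX on the fly and wrap it for ONNX Runtime inference.
model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True)

# One second of silence at 16 kHz stands in for real audio.
audio = np.zeros(16000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

generated_ids = model.generate(inputs["input_features"])
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```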