diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index f857ea3eaf..8811c3f872 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -120,6 +120,7 @@ class OnnxConfig(ExportConfig, ABC):
         "seq2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
         "sequence-classification": OrderedDict({"logits": {0: "batch_size"}}),
         "token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
+        "speech2seq-lm": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
     }
 
     def __init__(
@@ -206,6 +207,17 @@ def is_torch_support_available(self) -> bool:
             return TORCH_VERSION >= self.MIN_TORCH_VERSION
         return False
 
+    @property
+    def torch_to_onnx_input_map(self) -> Mapping[str, str]:
+        """
+        Dictionary of keys to update the ONNX input names for export. Override this property when
+        the dummy input names and the exported ONNX input names need to be different.
+
+        Returns:
+            `Mapping[str, str]`: A dictionary mapping each dummy input name to the exported ONNX input name.
+        """
+        return {}
+
     def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int, str]]:
         """
         Re-orders the inputs using the model forward pass signature.
@@ -218,6 +230,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
             `Mapping[str, Mappingp[int, str]]`: The properly ordered inputs.
         """
         inputs = self.inputs
+        ordered_inputs = {}
         sig = inspect.signature(model.forward)
         for param in sig.parameters:
@@ -229,6 +242,7 @@ def ordered_inputs(self, model: "PreTrainedModel") -> Mapping[str, Mapping[int,
         # TODO: figure out a smart way of re-ordering potential nested structures.
         # to_insert = sorted(to_insert, key=lambda t: t[0])
         for name, dynamic_axes in to_insert:
+            name = self.torch_to_onnx_input_map.get(name, name)
             ordered_inputs[name] = dynamic_axes
         return ordered_inputs
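The new `torch_to_onnx_input_map` hook lets a config rename its dummy inputs at export time; `ordered_inputs` now routes every input name through it. A minimal sketch of how a config subclass would use it (the class name is hypothetical; the mapping is the one used by `SpeechSeq2SeqDecoderOnnxConfig` further down in this patch):

    class MyDecoderOnnxConfig(OnnxConfig):
        @property
        def torch_to_onnx_input_map(self) -> Mapping[str, str]:
            # dummy input name -> input name in the exported ONNX graph
            return {"decoder_input_ids": "input_ids", "encoder_outputs": "encoder_hidden_states"}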
diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py
index 9869cbfc57..4e72ccae8c 100644
--- a/optimum/exporters/onnx/config.py
+++ b/optimum/exporters/onnx/config.py
@@ -17,6 +17,7 @@ from typing import Mapping
 from ...utils import (
+    DummyAudioInputGenerator,
     DummyBboxInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyPastKeyValuesGenerator,
@@ -99,3 +100,15 @@ class VisionOnnxConfig(OnnxConfig):
 class TextAndVisionOnnxConfig(OnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator, DummyBboxInputGenerator)
+
+
+class AudioOnnxConfig(OnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyAudioInputGenerator,)
+
+
+class TextAndAudioOnnxConfig(Seq2SeqOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyAudioInputGenerator,
+        DummyDecoderTextInputGenerator,
+        DummySeq2SeqPastKeyValuesGenerator,
+    )
diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index 6c92576c2d..284fce2e29 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -125,6 +125,7 @@ def export_pytorch(
     else:
         # Export can work with named args but the dict containing named args has to be the last element of the args
         # tuple.
+
         onnx_export(
             model,
             (dummy_inputs,),
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index cdf4f70123..98b4fdecfc 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -21,16 +21,26 @@ from ...utils import (
     DummyDecoderTextInputGenerator,
     DummyPastKeyValuesGenerator,
+    DummySeq2SeqDecoderTextInputGenerator,
     DummySeq2SeqPastKeyValuesGenerator,
     DummyTextInputGenerator,
     DummyVisionInputGenerator,
+    NormalizedConfig,
     NormalizedSeq2SeqConfig,
     NormalizedTextAndVisionConfig,
     NormalizedTextConfig,
     NormalizedVisionConfig,
 )
 from .base import OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
-from .config import DecoderOnnxConfig, EncoderOnnxConfig, Seq2SeqOnnxConfig, TextAndVisionOnnxConfig, VisionOnnxConfig
+from .config import (
+    AudioOnnxConfig,
+    DecoderOnnxConfig,
+    EncoderOnnxConfig,
+    Seq2SeqOnnxConfig,
+    TextAndAudioOnnxConfig,
+    TextAndVisionOnnxConfig,
+    VisionOnnxConfig,
+)
 
 
 if TYPE_CHECKING:
@@ -626,3 +636,72 @@ def generate_dummy_inputs(self, framework: str = "pt"):
         self.is_generating_dummy_inputs = True
         dummy_inputs[self.inputs_name] = dummy_inputs.pop(specialized_inputs_name)
         return dummy_inputs
+
+
+class WhisperOnnxConfig(TextAndAudioOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
+    ATOL_FOR_VALIDATION = 1e-3
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = {
+            "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
+        }
+        if self.use_past:
+            common_inputs["decoder_input_ids"] = {0: "batch_size"}
+        else:
+            common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+
+        if self.use_past:
+            self.add_past_key_values(common_inputs, direction="inputs")
+
+        return common_inputs
+
+
+class SpeechSeq2SeqEncoderOnnxConfig(AudioOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {
+            "input_features": {0: "batch_size", 1: "feature_size", 2: "encoder_sequence_length"},
+        }
+
+
+class SpeechSeq2SeqDecoderOnnxConfig(Seq2SeqOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummySeq2SeqDecoderTextInputGenerator,
+        DummyDecoderTextInputGenerator,
+        DummySeq2SeqPastKeyValuesGenerator,
+    )
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = {
+            "decoder_input_ids": {0: "batch_size", 1: "past_decoder_sequence_length + sequence_length"},
+            "encoder_outputs": {0: "batch_size", 1: "encoder_sequence_length"},
+        }
+
+        if self.use_past:
+            self.add_past_key_values(common_inputs, direction="inputs")
+
+        return common_inputs
+
+    @property
+    def torch_to_onnx_input_map(self) -> Mapping[str, str]:
+        return {"decoder_input_ids": "input_ids", "encoder_outputs": "encoder_hidden_states"}
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_outputs = super().outputs
+        self.add_past_key_values(common_outputs, direction="outputs")
+        return common_outputs
+
+    @property
+    def values_override(self) -> Optional[Mapping[str, Any]]:
+        if hasattr(self._config, "use_cache"):
+            return {"use_cache": True}
+
+        return None
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index e79ce64999..55c58f3e44 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -47,6 +47,7 @@
         AutoModelForSemanticSegmentation,
         AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
+        AutoModelForSpeechSeq2Seq,
         AutoModelForTokenClassification,
     )
 if is_tf_available():
@@ -122,6 +123,7 @@ class TasksManager:
         "image-segmentation": AutoModelForImageSegmentation,
         "masked-im": AutoModelForMaskedImageModeling,
         "semantic-segmentation": AutoModelForSemanticSegmentation,
+        "speech2seq-lm": AutoModelForSpeechSeq2Seq,
     }
     if is_tf_available():
         _TASKS_TO_TF_AUTOMODELS = {
@@ -506,6 +508,13 @@ class TasksManager:
             onnx="T5OnnxConfig",
         ),
         "vit": supported_tasks_mapping("default", "image-classification", "masked-im", onnx="ViTOnnxConfig"),
+        "whisper": supported_tasks_mapping(
+            "default",
+            "default-with-past",
+            "speech2seq-lm",
+            "speech2seq-lm-with-past",
+            onnx="WhisperOnnxConfig",
+        ),
         "xlm": supported_tasks_mapping(
             "default",
             "masked-lm",
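With "whisper" registered in the task table above, a Whisper checkpoint can be exported through the new exporters API. A rough sketch, assuming `export` is called the way `_from_transformers` calls it further down in this patch, and that the exporters' `OnnxConfig` exposes a `DEFAULT_ONNX_OPSET` attribute:

    from pathlib import Path

    from transformers import AutoModelForSpeechSeq2Seq

    from optimum.exporters.onnx.convert import export
    from optimum.exporters.onnx.model_configs import WhisperOnnxConfig

    model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny.en")
    onnx_config = WhisperOnnxConfig(model.config, task="speech2seq-lm")
    # export(model, config, opset, output) mirrors the call made in modeling_seq2seq.py below
    export(model, onnx_config, onnx_config.DEFAULT_ONNX_OPSET, Path("whisper.onnx"))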
""" +WHISPER_ENCODER_INPUTS_DOCSTRING = r""" + Arguments: + input_features (`torch.FloatTensor`): + Mel features extracted from the raw speech waveform. `(batch_size, feature_size, encoder_sequence_length)`. +""" + DECODER_INPUTS_DOCSTRING = r""" Arguments: @@ -118,7 +124,25 @@ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. """ + +SPEECH_SEQ2SEQ_ONNX_MODEL_DOCSTRING = r""" + Arguments: + input_features (`torch.FloatTensor`): + Mel features extracted from the raw speech waveform. + `(batch_size, feature_size, encoder_sequence_length)`. + decoder_input_ids (`torch.LongTensor`): + Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`. + encoder_outputs (`torch.FloatTensor`): + The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor), *optional*)` + Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding. + The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)` and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. +""" + _TOKENIZER_FOR_DOC = "AutoTokenizer" +_PROCESSOR_FOR_DOC = "AutoProcessor" TRANSLATION_EXAMPLE = r""" Example of text generation: @@ -152,6 +176,41 @@ """ +AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE = r""" + Example of text generation: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> from datasets import load_dataset + + >>> processor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor.feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") + + >>> gen_tokens = model.generate(inputs=inputs.input_features) + >>> outputs = processor.tokenizer.batch_decode(gen_tokens) + ``` + + Example using `transformers.pipeline`: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.onnxruntime import {model_class} + >>> from datasets import load_dataset + + >>> processor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> speech_recognition = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor) + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> pred = speech_recognition(ds[0]["audio"]["array"]) + ``` +""" + + @add_start_docstrings( """ Sequence-to-sequence model with a language modeling head for ONNX Runtime inference. 
@@ -161,8 +220,6 @@ class ORTModelForConditionalGeneration(ORTModel):
     # Used in from_transformers to export model to onnx
     base_model_prefix = "onnx_model"
-    export_feature = "seq2seq-lm"
-    auto_model_class = AutoModelForSeq2SeqLM
 
     def __init__(
         self,
@@ -186,13 +243,15 @@ def __init__(
         )
         self.use_io_binding = False
 
-        self.encoder = ORTEncoder(
+        self.encoder = self._initialize_encoder(
             session=encoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding
         )
         self.decoder = ORTDecoder(
             session=decoder_session, config=self.config, device=self._device, use_io_binding=self.use_io_binding
         )
+        self.use_cache = decoder_with_past_session is not None
+
         # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs
         # will be enabled
         self.decoder_with_past = (
@@ -205,6 +264,7 @@ def __init__(
             if self.use_cache
             else None
         )
+
         self.encoder_file_name = kwargs.get("last_encoder_model_name", ONNX_ENCODER_NAME)
         self.decoder_file_name = kwargs.get("last_decoder_model_name", ONNX_DECODER_NAME)
         self.decoder_file_with_past_name = kwargs.get("last_decoder_with_past_model_name", ONNX_DECODER_WITH_PAST_NAME)
@@ -475,43 +535,78 @@ def _from_transformers(
         _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=cls.export_feature)
         onnx_config = model_onnx_config(model.config)
         onnx_opset = onnx_config.default_onnx_opset
-        onnx_config_encoder = EncoderOnnxConfig(model.config, task="default")
-        onnx_config_decoder = DecoderOnnxConfig(model.config, task=cls.export_feature, use_past=False)
-        onnx_config_decoder_with_past = DecoderOnnxConfig(model.config, task=cls.export_feature, use_past=True)
 
-        # Extract the encoder for ONNX export
+        # Extract the encoder and decoder for ONNX export
         encoder = model.get_encoder()
+        decoder = model.get_decoder()
+
         # Concatenate the decoder with the language model head for ONNX export
         decoder_with_lm_head = _DecoderWithLMhead(model)
 
-        # Export the encoder
-        export(
-            preprocessor=tokenizer,
-            model=encoder,
-            config=onnx_config_encoder,
-            opset=onnx_opset,
-            output=save_dir.joinpath(ONNX_ENCODER_NAME),
-        )
+        # Get the encoder and decoder ONNX configs
+        onnx_config_encoder = cls.get_encoder_onnx_config(encoder.config)
+        onnx_config_decoder = cls.get_decoder_onnx_config(decoder.config, cls.export_feature, use_past=False)
+        if use_cache:
+            onnx_config_decoder_with_past = cls.get_decoder_onnx_config(
+                decoder.config, cls.export_feature, use_past=True
+            )
 
-        # Export the decoder without the past key values
-        export(
-            preprocessor=tokenizer,
-            model=decoder_with_lm_head,
-            config=onnx_config_decoder,
-            opset=onnx_opset,
-            output=save_dir.joinpath(ONNX_DECODER_NAME),
-        )
+        if config.model_type == "whisper":
+            from ..exporters.onnx.convert import export as export_optimum
 
-        # Export the decoder with the past key values
-        if use_cache:
+            # Export the encoder
+            export_optimum(
+                encoder,
+                onnx_config_encoder,
+                onnx_opset,
+                save_dir.joinpath(ONNX_ENCODER_NAME),
+            )
+
+            # Export the decoder without the past key values
+            export_optimum(
+                model,
+                onnx_config_decoder,
+                onnx_opset,
+                save_dir.joinpath(ONNX_DECODER_NAME),
+            )
+
+            # Export the decoder with the past key values
+            if use_cache:
+                export_optimum(
+                    model,
+                    onnx_config_decoder_with_past,
+                    onnx_opset,
+                    save_dir.joinpath(ONNX_DECODER_WITH_PAST_NAME),
+                )
+        else:
+            # Export the encoder
+            export(
+                preprocessor=tokenizer,
+                model=encoder,
+                config=onnx_config_encoder,
+                opset=onnx_opset,
+                output=save_dir.joinpath(ONNX_ENCODER_NAME),
+            )
+
+            # Export the decoder without the past key values
             export(
                 preprocessor=tokenizer,
                 model=decoder_with_lm_head,
-                config=onnx_config_decoder_with_past,
+                config=onnx_config_decoder,
                 opset=onnx_opset,
-                output=save_dir.joinpath(ONNX_DECODER_WITH_PAST_NAME),
+                output=save_dir.joinpath(ONNX_DECODER_NAME),
             )
+            # Export the decoder with the past key values
+            if use_cache:
+                export(
+                    preprocessor=tokenizer,
+                    model=decoder_with_lm_head,
+                    config=onnx_config_decoder_with_past,
+                    opset=onnx_opset,
+                    output=save_dir.joinpath(ONNX_DECODER_WITH_PAST_NAME),
+                )
+
         kwargs["config"] = model.config
         return cls._from_pretrained(save_dir, **kwargs)
@@ -558,13 +653,14 @@ def __init__(
         config: transformers.PretrainedConfig,
         device: torch.device,
         use_io_binding: bool = True,
+        main_input_name: str = "input_ids",
         **kwargs
     ):
         self.session = session
         self.config = config
         self._device = device
         self.use_io_binding = use_io_binding
-        self.main_input_name = "input_ids"
+        self.main_input_name = main_input_name
         self.normalized_config = ORTConfigManager.get_normalized_config_class(self.config.model_type)(self.config)
         self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
         self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
@@ -628,7 +724,7 @@ def prepare_io_binding(
         return io_binding, output_shapes, output_buffers
 
-    @add_start_docstrings_to_model_forward(ENCODER_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(SEQ2SEQ_ENCODER_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -665,6 +761,77 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
 
+class ORTEncoderForWhisper(ORTEncoder):
+    """
+    Encoder model for ONNX Runtime inference for the Whisper model.
+
+    Arguments:
+        session (`onnxruntime.InferenceSession`):
+            The ONNX Runtime inference session associated to the encoder.
+    """
+
+    def prepare_io_binding(
+        self,
+        input_features: torch.FloatTensor = None,
+    ):
+        io_binding = self.session.io_binding()
+
+        # bind input_features
+        io_binding.bind_input(
+            "input_features",
+            input_features.device.type,
+            self._device.index,
+            self.name_to_np_type["input_features"],
+            tuple(input_features.shape),
+            input_features.data_ptr(),
+        )
+
+        # bind last_hidden_state
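+        # (the Whisper encoder halves the number of mel frames with its stride-2
+        # convolution, so the output buffer below is sized with the frame count // 2)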
+ """ + + def prepare_io_binding( + self, + input_features: torch.FloatTensor = None, + ): + io_binding = self.session.io_binding() + + # bind input ids + io_binding.bind_input( + "input_features", + input_features.device.type, + self._device.index, + self.name_to_np_type["input_features"], + tuple(input_features.shape), + input_features.data_ptr(), + ) + + # bind logits + output_shape, output_buffer = self.prepare_output_buffer( + batch_size=input_features.size(0), + sequence_length=input_features.size(2) // 2, + ) + io_binding.bind_output( + "last_hidden_state", + output_buffer.device.type, + self._device.index, + self.name_to_np_type["last_hidden_state"], + output_shape, + output_buffer.data_ptr(), + ) + output_shapes = {"last_hidden_state": output_shape} + output_buffers = {"last_hidden_state": output_buffer} + + return io_binding, output_shapes, output_buffers + + @add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING) + def forward( + self, + input_features: torch.FloatTensor, + **kwargs, + ) -> BaseModelOutput: + if self._device.type == "cuda" and self.use_io_binding: + io_binding, output_shapes, output_buffers = self.prepare_io_binding(input_features) + + # run inference with binding & synchronize in case of multiple CUDA streams + io_binding.synchronize_inputs() + self.session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + # converts output to namedtuple for pipelines post-processing + return BaseModelOutput( + last_hidden_state=output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) + ) + else: + onnx_inputs = {"input_features": input_features.cpu().detach().numpy()} + + # Run inference + outputs = self.session.run(None, onnx_inputs) + last_hidden_state = torch.from_numpy(outputs[self.output_names["last_hidden_state"]]).to(self._device) + + return BaseModelOutput(last_hidden_state=last_hidden_state) + + class ORTDecoder: """ Decoder model with a language modeling head on top for ONNX Runtime inference. 
@@ -691,8 +858,12 @@ def __init__(
         self.session_outputs = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
         self.session_input_names = list(self.session_inputs.keys())
         self.session_output_names = list(self.session_outputs.keys())
-        self.key_value_input_names = [key for key in self.session_input_names if "key_values" in key]
-        self.key_value_output_names = [key for key in self.session_output_names if "key_values" in key]
+        self.key_value_input_names = [
+            key for key in self.session_input_names if ("key_values" in key or ".key" in key or ".value" in key)
+        ]
+        self.key_value_output_names = [
+            key for key in self.session_output_names if ("key_values" in key or ".key" in key or ".value" in key)
+        ]
         self.name_to_np_type = TypeHelper.get_io_numpy_type_map(self.session) if self.use_io_binding else None
 
     def prepare_output_buffer(
@@ -715,7 +886,7 @@ def prepare_output_buffer(
         elif output_name == "logits":
             output_shape = (batch_size, sequence_length, self.normalized_config.vocab_size)
             output_buffer = torch.empty(np.prod(output_shape), dtype=torch_type, device=self._device).contiguous()
-        elif "key_values" in output_name:
+        elif "key_values" in output_name or ".key" in output_name or ".value" in output_name:
             num_attention_heads = self.normalized_config.num_attention_heads
             hidden_size = self.normalized_config.hidden_size
             embed_size_per_head = hidden_size // num_attention_heads
@@ -752,15 +923,16 @@ def prepare_io_binding(
         )
 
         # bind encoder attention mask
-        encoder_attention_mask = encoder_attention_mask.contiguous()
-        io_binding.bind_input(
-            "encoder_attention_mask",
-            encoder_attention_mask.device.type,
-            self._device.index,
-            self.name_to_np_type["encoder_attention_mask"],
-            list(encoder_attention_mask.size()),
-            encoder_attention_mask.data_ptr(),
-        )
+        if "encoder_attention_mask" in self.session_input_names:
+            encoder_attention_mask = encoder_attention_mask.contiguous()
+            io_binding.bind_input(
+                "encoder_attention_mask",
+                encoder_attention_mask.device.type,
+                self._device.index,
+                self.name_to_np_type["encoder_attention_mask"],
+                list(encoder_attention_mask.size()),
+                encoder_attention_mask.data_ptr(),
+            )
 
         # bind encoder hidden states
         if "encoder_hidden_states" in self.session_input_names:
@@ -910,7 +1082,7 @@ def forward(
             # self-attention layer and 2 to the cross-attention layer)
             past_key_values = tuple()
             for name in self.session_output_names:
-                if "key_values" in name:
+                if "key_values" in name or ".key" in name or ".value" in name:
                     past_key_values += (output_buffers[name].view(output_shapes[name]),)
 
             # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
@@ -926,9 +1098,12 @@ def forward(
         else:
             onnx_inputs = {
                 "input_ids": input_ids.cpu().detach().numpy(),
-                "encoder_attention_mask": encoder_attention_mask.cpu().detach().numpy(),
             }
 
+            # Add the encoder_attention_mask inputs when needed
+            if "encoder_attention_mask" in self.session_input_names:
+                onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy()
+
             # Add the encoder_hidden_states inputs when needed
             if "encoder_hidden_states" in self.session_input_names:
                 onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy()
@@ -949,7 +1124,7 @@ def forward(
             past_key_values = tuple(
                 torch.from_numpy(outputs[self.session_outputs[key]]).to(self._device)
                 for key in self.session_output_names
-                if "key_values" in key
+                if "key_values" in key or ".key" in key or ".value" in key
             )
 
             # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
@@ -974,9 +1149,35 @@ class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin):
     Sequence-to-sequence model with a language modeling head for ONNX Runtime inference.
     """
 
+    export_feature = "seq2seq-lm"
+    auto_model_class = AutoModelForSeq2SeqLM
+    main_input_name = "input_ids"
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.main_input_name = "input_ids"
+
+    def _initialize_encoder(
+        self,
+        session: onnxruntime.InferenceSession,
+        config: transformers.PretrainedConfig,
+        device: torch.device,
+        use_io_binding: bool = True,
+    ) -> ORTEncoder:
+        return ORTEncoder(
+            session=session,
+            config=config,
+            device=device,
+            use_io_binding=use_io_binding,
+            main_input_name=self.main_input_name,
+        )
+
+    def get_encoder_onnx_config(encoder_config: PretrainedConfig) -> EncoderOnnxConfig:
+        return EncoderOnnxConfig(encoder_config, task="default")
+
+    def get_decoder_onnx_config(
+        decoder_config: PretrainedConfig, export_feature: str, use_past: bool = False
+    ) -> DecoderOnnxConfig:
+        return DecoderOnnxConfig(decoder_config, export_feature, use_past=use_past)
 
     @add_start_docstrings_to_model_forward(
         SEQ2SEQ_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length")
@@ -1061,3 +1262,128 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
         return reordered_past
+
+
+class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin):
+    """
+    Speech sequence-to-sequence model with a language modeling head for ONNX Runtime inference.
+    """
+
+    export_feature = "speech2seq-lm"
+    auto_model_class = AutoModelForSpeechSeq2Seq
+    main_input_name = "input_features"
+
+    _MODEL_TYPE_TO_ORTENCODER = {
+        "whisper": ORTEncoderForWhisper,
+    }
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _initialize_encoder(
+        self,
+        session: onnxruntime.InferenceSession,
+        config: transformers.PretrainedConfig,
+        device: torch.device,
+        use_io_binding: bool = True,
+    ) -> ORTEncoder:
+        if config.model_type not in self._MODEL_TYPE_TO_ORTENCODER:
+            raise KeyError(
+                f"{config.model_type} is not supported yet. "
+                f"Only {list(self._MODEL_TYPE_TO_ORTENCODER.keys())} are supported. "
+                f"If you want to support {config.model_type} please propose a PR or open up an issue."
+            )
+        return self._MODEL_TYPE_TO_ORTENCODER[config.model_type](
+            session=session,
+            config=config,
+            device=device,
+            use_io_binding=use_io_binding,
+            main_input_name=self.main_input_name,
+        )
+
+    def get_encoder_onnx_config(encoder_config: PretrainedConfig) -> SpeechSeq2SeqEncoderOnnxConfig:
+        return SpeechSeq2SeqEncoderOnnxConfig(encoder_config, task="default")
+
+    def get_decoder_onnx_config(
+        decoder_config: PretrainedConfig, export_feature: str, use_past: bool = False
+    ) -> SpeechSeq2SeqDecoderOnnxConfig:
+        return SpeechSeq2SeqDecoderOnnxConfig(decoder_config, export_feature, use_past=use_past)
+
+    @add_start_docstrings_to_model_forward(
+        SPEECH_SEQ2SEQ_ONNX_MODEL_DOCSTRING.format("batch_size, feature_size, sequence_length")
+        + AUTOMATIC_SPEECH_RECOGNITION_EXAMPLE.format(
+            processor_class=_PROCESSOR_FOR_DOC,
+            model_class="ORTModelForSpeechSeq2Seq",
+            checkpoint="optimum/whisper-tiny.en",
+        )
+    )
+    def forward(
+        self,
+        input_features: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Seq2SeqLMOutput:
+
+        # Encode if needed: first prediction pass
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(input_features=input_features)
+
+        # Decode
+        if past_key_values is None or self.decoder_with_past is None:
+            decoder_outputs = self.decoder(
+                input_ids=decoder_input_ids,
+                encoder_hidden_states=encoder_outputs.last_hidden_state,
+                labels=labels,
+            )
+        else:
+            decoder_outputs = self.decoder_with_past(
+                input_ids=decoder_input_ids[:, -1:],  # Cut decoder_input_ids if past is used
+                past_key_values=past_key_values,
+                encoder_hidden_states=encoder_outputs.last_hidden_state,
+                labels=labels,
+            )
+
+        return Seq2SeqLMOutput(
+            loss=decoder_outputs.get("loss", None),
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs
+    ) -> Dict:
+
+        return {
+            "decoder_input_ids": input_ids,
+            "past_key_values": past,
+            "encoder_outputs": encoder_outputs,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,
+        }
+
+    def get_encoder(self) -> ORTEncoder:
+        return self.encoder
+
+    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
+    @staticmethod
+    def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
+        reordered_past = ()
+        for layer_past in past:
+            # Cached cross_attention states don't have to be reordered -> they are always the same
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+            )
+        return reordered_past
diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py
index c804545d48..9443641f2d 100644
--- a/optimum/onnxruntime/utils.py
+++ b/optimum/onnxruntime/utils.py
@@ -58,6 +58,9 @@ def _is_gpu_available():
     num_attention_heads="num_heads",
     hidden_size="d_model",
 )
+WhisperLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
+    hidden_size="d_model",
+)
 
 
 class ORTConfigManager:
@@ -91,6 +94,7 @@ class ORTConfigManager:
         "m2m_100": (BartLikeNormalizedTextConfig, "bart"),
         "roberta": (NormalizedTextConfig, "bert"),
         "t5": (T5LikeNormalizedTextConfig, "t5"),
+        "whisper": (WhisperLikeNormalizedTextConfig, "whisper"),
         "xlm-roberta": (NormalizedTextConfig, "bert"),
     }
 
diff --git a/optimum/pipelines.py b/optimum/pipelines.py
index 381571b7cd..a1d30a9fcb 100644
--- a/optimum/pipelines.py
+++ b/optimum/pipelines.py
@@ -1,6 +1,7 @@
 from typing import Any, Optional, Union
 
 from transformers import (
+    AutomaticSpeechRecognitionPipeline,
     FeatureExtractionPipeline,
     ImageClassificationPipeline,
     Pipeline,
@@ -31,6 +32,7 @@
         ORTModelForQuestionAnswering,
         ORTModelForSeq2SeqLM,
         ORTModelForSequenceClassification,
+        ORTModelForSpeechSeq2Seq,
         ORTModelForTokenClassification,
     )
     from .onnxruntime.modeling_ort import ORTModel
@@ -96,6 +98,12 @@
         "default": "t5-small",
         "type": "text",
     },
+    "automatic-speech-recognition": {
+        "impl": AutomaticSpeechRecognitionPipeline,
+        "class": (ORTModelForSpeechSeq2Seq,) if is_onnxruntime_available() else (),
+        "default": "openai/whisper-tiny.en",
+        "type": "multimodal",
+    },
 }
 
 NO_FEATURE_EXTRACTOR_TASKS = set()
@@ -105,8 +113,8 @@
         NO_FEATURE_EXTRACTOR_TASKS.add(task)
     elif values["type"] == "image":
         NO_TOKENIZER_TASKS.add(task)
-    else:
-        raise ValueError(f"Supported types are 'text' and 'image', got {values['type']}")
+    elif values["type"] != "multimodal":
+        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
 
 
 def load_bettertransformer(
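The new SUPPORTED_TASKS entry makes automatic speech recognition available through `optimum.pipelines`. A rough usage sketch, assuming the `pipeline` helper falls back to the default "openai/whisper-tiny.en" checkpoint registered above when no model is passed (the audio path is a placeholder):

    from optimum.pipelines import pipeline

    asr = pipeline("automatic-speech-recognition")
    print(asr("sample.wav")["text"])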
diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py
index 774e408a16..8e8cc167e6 100644
--- a/optimum/utils/__init__.py
+++ b/optimum/utils/__init__.py
@@ -69,9 +69,11 @@ def check_if_pytorch_greater_112():
 
 from .input_generators import (  # noqa
+    DummyAudioInputGenerator,
     DummyBboxInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyPastKeyValuesGenerator,
+    DummySeq2SeqDecoderTextInputGenerator,
     DummySeq2SeqPastKeyValuesGenerator,
     DummyTextInputGenerator,
     DummyVisionInputGenerator,
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index d3325567ba..35c7f1c249 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -243,6 +243,45 @@ class DummyDecoderTextInputGenerator(DummyTextInputGenerator):
     )
 
 
+class DummySeq2SeqDecoderTextInputGenerator(DummyDecoderTextInputGenerator):
+    SUPPORTED_INPUT_NAMES = (
+        "decoder_input_ids",
+        "decoder_attention_mask",
+        "encoder_outputs",
+    )
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = 2,
+        sequence_length: int = 16,
+        num_choices: int = 4,
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        random_num_choices_range: Optional[Tuple[int, int]] = None,
+    ):
+        super().__init__(
+            task,
+            normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            num_choices=num_choices,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+            random_num_choices_range=random_num_choices_range,
+        )
+
+        self.hidden_size = normalized_config.hidden_size
+
+    def generate(self, input_name: str, framework: str = "pt"):
+        if input_name == "encoder_outputs":
+            shape = (self.batch_size, self.sequence_length, self.hidden_size)
+            return (self.random_float_tensor(shape, min_value=0, max_value=1, framework=framework), None, None)
+
+        return super().generate(input_name, framework=framework)
+
+
 class DummyPastKeyValuesGenerator(DummyInputGenerator):
     SUPPORTED_INPUT_NAMES = ("past_key_values",)
 
@@ -409,3 +448,31 @@ def generate(self, input_name: str, framework: str = "pt"):
             return self.random_int_tensor(shape, max_value=1, framework=framework)
         shape = [self.batch_size, self.num_channels, self.height, self.width]
         return self.random_float_tensor(shape, framework=framework)
+
+
+class DummyAudioInputGenerator(DummyInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("input_features", "input_values")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedConfig,
+        batch_size: int = 2,
+        feature_size: int = 80,
+        nb_max_frames: int = 3000,
+        sequence_length: int = 16000,
+    ):
+        self.task = task
+
+        self.feature_size = feature_size
+        self.nb_max_frames = nb_max_frames
+        self.batch_size = batch_size
+        self.sequence_length = sequence_length
+
+    def generate(self, input_name: str, framework: str = "pt"):
+        shape = [self.batch_size, self.sequence_length]
+        if input_name == "input_values":
+            return self.random_float_tensor(shape, min_value=-1, max_value=1, framework=framework)
+
+        shape = [self.batch_size, self.feature_size, self.nb_max_frames]
+        return self.random_float_tensor(shape, min_value=-1, max_value=1, framework=framework)
diff --git a/tests/exporters/test_onnx_export.py b/tests/exporters/test_onnx_export.py
index 3a579a9f33..81f5063c1b 100644
--- a/tests/exporters/test_onnx_export.py
+++ b/tests/exporters/test_onnx_export.py
@@ -96,6 +96,7 @@
     ("bigbird-pegasus", "hf-internal-testing/tiny-random-bigbird_pegasus"),
     # Not using google/long-t5-local-base because it takes too much time for testing.
     ("longt5", "hf-internal-testing/tiny-random-longt5"),
+    ("whisper", "openai/whisper-tiny.en"),
 }
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index d9ea8cef2b..2a43df3511 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -18,6 +18,7 @@
 import tempfile
 import unittest
 
+import numpy as np
 import pytest
 import torch
 from PIL import Image
@@ -29,6 +30,7 @@
     AutoModelForQuestionAnswering,
     AutoModelForSeq2SeqLM,
     AutoModelForSequenceClassification,
+    AutoModelForSpeechSeq2Seq,
     AutoModelForTokenClassification,
     PretrainedConfig,
     set_seed,
@@ -53,6 +55,7 @@
     ORTModelForQuestionAnswering,
     ORTModelForSeq2SeqLM,
     ORTModelForSequenceClassification,
+    ORTModelForSpeechSeq2Seq,
     ORTModelForTokenClassification,
 )
 from optimum.onnxruntime.modeling_ort import ORTModel
@@ -81,6 +84,7 @@
     "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "vit": "hf-internal-testing/tiny-random-vit",
+    "whisper": "openai/whisper-tiny.en",
 }
 
 SEED = 42
@@ -1445,6 +1449,175 @@ def test_compare_generation_to_io_binding(self, model_arch):
         gc.collect()
 
 
+class ORTModelForSpeechSeq2SeqIntegrationTest(unittest.TestCase):
+    SUPPORTED_ARCHITECTURES = ("whisper",)
+
+    def _generate_random_audio_data(self):
+        np.random.seed(10)
+        t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False)
+        # generate pure sine wave at 220 Hz
+        audio_data = 0.5 * np.sin(2 * np.pi * 220 * t)
+        return audio_data
+
+    def test_load_vanilla_transformers_which_is_not_supported(self):
+        with self.assertRaises(Exception) as context:
+            _ = ORTModelForSpeechSeq2Seq.from_pretrained(MODEL_NAMES["bert"], from_transformers=True)
+
+        self.assertIn("Unrecognized configuration class", str(context.exception))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_generate_utils(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True)
+        processor = get_preprocessor(model_id)
+
+        data = self._generate_random_audio_data()
+        features = processor.feature_extractor(data, return_tensors="pt")
+
+        outputs = model.generate(inputs=features["input_features"])
+        res = processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        self.assertIsInstance(res[0], str)
+
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_compare_to_transformers(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        onnx_model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True)
+
+        self.assertIsInstance(onnx_model.encoder, ORTEncoder)
+        self.assertIsInstance(onnx_model.decoder, ORTDecoder)
+        self.assertIsInstance(onnx_model.decoder_with_past, ORTDecoder)
+        self.assertIsInstance(onnx_model.config, PretrainedConfig)
+
+        set_seed(SEED)
+        transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
+        processor = get_preprocessor(model_id)
+
+        data = self._generate_random_audio_data()
+        features = processor.feature_extractor(data, return_tensors="pt")
+
+        decoder_start_token_id = transformers_model.config.decoder_start_token_id
+        decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}
+        onnx_outputs = onnx_model(**features, **decoder_inputs)
+
+        self.assertTrue("logits" in onnx_outputs)
+        self.assertIsInstance(onnx_outputs.logits, torch.Tensor)
+
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**features, **decoder_inputs)
+
+        # Compare tensor outputs
+        self.assertTrue(torch.allclose(onnx_outputs.logits, transformers_outputs.logits, atol=1e-4))
+
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_pipeline_speech_recognition(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True)
+        processor = get_preprocessor(model_id)
+
+        # Speech recognition generation
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model=onnx_model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+        )
+        data = self._generate_random_audio_data()
+        outputs = pipe(data)
+        self.assertEqual(pipe.device, onnx_model.device)
+        self.assertIsInstance(outputs["text"], str)
+
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @require_torch_gpu
+    def test_pipeline_on_gpu(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True)
+        processor = get_preprocessor(model_id)
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model=onnx_model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+        )
+
+        data = self._generate_random_audio_data()
+        outputs = pipe(data)
+
+        # check model device
+        self.assertEqual(pipe.model.device.type.lower(), "cuda")
+        # compare model output class
+        self.assertTrue(isinstance(outputs["text"], str))
+
+    def test_compare_with_and_without_past_key_values_model_outputs(self):
+        model_id = MODEL_NAMES["whisper"]
+        processor = get_preprocessor(model_id)
+
+        data = self._generate_random_audio_data()
+        features = processor.feature_extractor(data, return_tensors="pt")
+
+        model_with_pkv = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True, use_cache=True)
+        outputs_model_with_pkv = model_with_pkv.generate(**features)
+        model_without_pkv = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True, use_cache=False)
+        outputs_model_without_pkv = model_without_pkv.generate(**features)
+
+        self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @require_torch_gpu
+    def test_compare_to_io_binding(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        onnx_model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True, use_io_binding=False)
+        set_seed(SEED)
+        io_model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True, use_io_binding=True)
+
+        processor = get_preprocessor(model_id)
+
+        data = self._generate_random_audio_data()
+        features = processor.feature_extractor([data] * 2, return_tensors="pt")
+
+        decoder_start_token_id = onnx_model.config.decoder_start_token_id
+        decoder_inputs = {"decoder_input_ids": torch.ones((2, 1), dtype=torch.long) * decoder_start_token_id}
+
+        onnx_outputs = onnx_model(**features, **decoder_inputs)
+        io_outputs = io_model(**features, **decoder_inputs)
+
+        self.assertTrue("logits" in io_outputs)
+        self.assertIsInstance(io_outputs.logits, torch.Tensor)
+
+        # compare tensor outputs
+        self.assertTrue(torch.equal(onnx_outputs.logits, io_outputs.logits))
+
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @require_torch_gpu
+    def test_compare_generation_to_io_binding(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        onnx_model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True, use_io_binding=False)
+        set_seed(SEED)
+        io_model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, from_transformers=True, use_io_binding=True)
+
+        processor = get_preprocessor(model_id)
+
+        data = self._generate_random_audio_data()
+        features = processor.feature_extractor(data, return_tensors="pt")
+
+        onnx_outputs = onnx_model.generate(**features, num_beams=5)
+        io_outputs = io_model.generate(**features, num_beams=5)
+
+        # compare tensor outputs
+        self.assertTrue(torch.equal(onnx_outputs, io_outputs))
+
+        gc.collect()
+
+
 class ORTModelForCustomTasksIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = {
         "sbert": "optimum/sbert-all-MiniLM-L6-with-pooler",