diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py
index f5b856a7cf..859cbc0af0 100644
--- a/optimum/exporters/onnx/__main__.py
+++ b/optimum/exporters/onnx/__main__.py
@@ -405,6 +405,9 @@ def main_export(
 
     # Saving the model config and preprocessor as this is needed sometimes.
     model.config.save_pretrained(output)
+    generation_config = getattr(model, "generation_config", None)
+    if generation_config is not None:
+        generation_config.save_pretrained(output)
     maybe_save_preprocessors(model_name_or_path, output)
 
     if task == "stable-diffusion":
diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index 26333153ae..bee6ec1c01 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -18,7 +18,7 @@
 import logging
 import shutil
-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
@@ -26,7 +26,13 @@
 import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
-from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, GenerationConfig
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    AutoModelForSpeechSeq2Seq,
+    AutoModelForVision2Seq,
+    GenerationConfig,
+    WhisperForConditionalGeneration,
+)
 from transformers.file_utils import add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
@@ -569,6 +575,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
             dst_path.parent.mkdir(parents=True, exist_ok=True)
             shutil.copyfile(src_path, dst_path)
 
+        self.generation_config.save_pretrained(save_directory)
+
     @classmethod
     def _from_pretrained(
         cls,
@@ -1046,6 +1054,50 @@ def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
         return True
 
+    @classmethod
+    def _from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        config: "PretrainedConfig",
+        **kwargs,
+    ):
+        if "WhisperForConditionalGeneration" in config.architectures:
+            return _ORTModelForWhisper._from_pretrained(model_id, config, **kwargs)
+        else:
+            return super()._from_pretrained(model_id, config, **kwargs)
+
+
+class MetaClassRemoveParentsAndReorder(ABCMeta):
+    def mro(cls):
+        """
+        Avoids inheriting from PreTrainedModel, nn.Module, ModuleUtilsMixin, PushToHubMixin,
+        and puts GenerationMixin at the end of the MRO.
+        """
+        top_inheritance_index = ORTModelForSpeechSeq2Seq.__mro__.index(GenerationMixin)
+        return (
+            (cls,)
+            + ORTModelForSpeechSeq2Seq.__mro__[:top_inheritance_index]
+            + (WhisperForConditionalGeneration,)
+            + ORTModelForSpeechSeq2Seq.__mro__[top_inheritance_index:]
+        )
+
+
+class _ORTModelForWhisper(
+    ORTModelForSpeechSeq2Seq, WhisperForConditionalGeneration, metaclass=MetaClassRemoveParentsAndReorder
+):
+    """
+    Whisper implements its own generate() method.
+ """ + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + **kwargs, + ): + return super(ORTModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs) + class ORTModelForVision2Seq(ORTModelForConditionalGeneration, GenerationMixin): """ diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 0c00c7e664..080b348753 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -3457,6 +3457,13 @@ def test_pipeline_speech_recognition(self, test_name: str, model_arch: str, use_ self.assertEqual(pipe.device, onnx_model.device) self.assertIsInstance(outputs["text"], str) + if model_arch == "whisper": + outputs = pipe(data, return_timestamps=True) + self.assertTrue("chunks" in outputs) + + outputs = pipe(data, return_timestamps=False) + self.assertTrue("chunks" not in outputs) + gc.collect() @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))