Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,9 @@ def main_export(

# Saving the model config and preprocessor as this is needed sometimes.
model.config.save_pretrained(output)
generation_config = getattr(model, "generation_config", None)
if generation_config is not None:
generation_config.save_pretrained(output)
maybe_save_preprocessors(model_name_or_path, output)

if task == "stable-diffusion":
Expand Down
56 changes: 54 additions & 2 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,21 @@

import logging
import shutil
from abc import ABC, abstractmethod
from abc import ABC, ABCMeta, abstractmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, GenerationConfig
from transformers import (
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
AutoModelForVision2Seq,
GenerationConfig,
WhisperForConditionalGeneration,
)
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

Expand Down Expand Up @@ -569,6 +575,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(src_path, dst_path)

self.generation_config.save_pretrained(save_directory)

@classmethod
def _from_pretrained(
cls,
Expand Down Expand Up @@ -1046,6 +1054,50 @@ def can_generate(self):
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
return True

@classmethod
def _from_pretrained(
    cls,
    model_id: Union[str, Path],
    config: "PretrainedConfig",
    **kwargs,
):
    """
    Loads an ONNX Runtime speech seq2seq model, dispatching to the Whisper-specific
    subclass when the checkpoint's architecture is Whisper.

    Args:
        model_id (`Union[str, Path]`):
            Hub model id or local directory holding the exported ONNX model.
        config (`PretrainedConfig`):
            The model configuration; its `architectures` field drives the dispatch.
        **kwargs:
            Forwarded unchanged to the parent `_from_pretrained` implementation.

    Returns:
        An instance of `_ORTModelForWhisper` for Whisper checkpoints, otherwise an
        instance built by the parent class loader.
    """
    # `PretrainedConfig.architectures` defaults to None when the config does not
    # declare it; guard so membership testing never raises TypeError.
    architectures = getattr(config, "architectures", None) or []
    if "WhisperForConditionalGeneration" in architectures:
        # Whisper implements its own generate(); route to the dedicated subclass.
        return _ORTModelForWhisper._from_pretrained(model_id, config, **kwargs)
    else:
        return super()._from_pretrained(model_id, config, **kwargs)


class MetaClassRemoveParentsAndReorder(ABCMeta):
    # Metaclass used by `_ORTModelForWhisper` to hand-craft its MRO instead of
    # accepting the one Python would linearize from its two heavyweight bases.
    def mro(cls):
        """
        Avoids inheriting from PreTrainedModel, nn.Module, ModuleUtilsMixin, PushToHubMixin,
        and puts GenerationMixin at the end of the MRO.

        The returned sequence is: the class itself, then every class that precedes
        GenerationMixin in ORTModelForSpeechSeq2Seq's MRO, then
        WhisperForConditionalGeneration, then GenerationMixin and the remainder of
        ORTModelForSpeechSeq2Seq's MRO. Placing WhisperForConditionalGeneration
        just before GenerationMixin lets Whisper's own methods (e.g. generate())
        take lookup precedence over GenerationMixin's, while the ORT classes stay
        in front of both.
        """
        # Split point: everything before GenerationMixin keeps its position;
        # GenerationMixin and what follows it (..., object) is pushed to the back.
        top_inheritance_index = ORTModelForSpeechSeq2Seq.__mro__.index(GenerationMixin)
        return (
            (cls,)
            + ORTModelForSpeechSeq2Seq.__mro__[:top_inheritance_index]
            + (WhisperForConditionalGeneration,)
            + ORTModelForSpeechSeq2Seq.__mro__[top_inheritance_index:]
        )


class _ORTModelForWhisper(
    ORTModelForSpeechSeq2Seq, WhisperForConditionalGeneration, metaclass=MetaClassRemoveParentsAndReorder
):
    """
    Whisper implements its own generate() method.

    Inheriting from WhisperForConditionalGeneration (with the MRO rewritten by the
    metaclass so it sits ahead of GenerationMixin) makes Whisper's generate() the
    one resolved on this class, while ONNX Runtime inference still comes from
    ORTModelForSpeechSeq2Seq.
    """

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        **kwargs,
    ):
        # Two-argument super() deliberately starts the lookup *after*
        # ORTModelForSpeechSeq2Seq in the MRO: that class's _from_pretrained
        # dispatches Whisper architectures back to this class, so calling it
        # here would recurse infinitely.
        return super(ORTModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs)


class ORTModelForVision2Seq(ORTModelForConditionalGeneration, GenerationMixin):
"""
Expand Down
7 changes: 7 additions & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3457,6 +3457,13 @@ def test_pipeline_speech_recognition(self, test_name: str, model_arch: str, use_
self.assertEqual(pipe.device, onnx_model.device)
self.assertIsInstance(outputs["text"], str)

if model_arch == "whisper":
outputs = pipe(data, return_timestamps=True)
self.assertTrue("chunks" in outputs)

outputs = pipe(data, return_timestamps=False)
self.assertTrue("chunks" not in outputs)

gc.collect()

@parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
Expand Down