Merged
Changes from 4 commits
optimum/exporters/onnx/base.py (1 addition & 2 deletions)
@@ -36,7 +36,7 @@

from optimum.exporters.base import ExporterConfig
from optimum.exporters.onnx.constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
-from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, Seq2SeqModelPatcher
+from optimum.exporters.onnx.model_patcher import ModelPatcher, Seq2SeqModelPatcher
from optimum.utils import DEFAULT_DUMMY_SHAPES, DummyInputGenerator, DummySeq2SeqPastKeyValuesGenerator, logging
from optimum.utils.doc import add_dynamic_docstring
from optimum.utils.import_utils import (
@@ -428,7 +428,6 @@ class OnnxConfigWithPast(OnnxConfig, ABC):

PAD_ATTENTION_MASK_TO_PAST: bool = False
SUPPORTS_PAST: bool = True
-    _MODEL_PATCHER = DecoderModelPatcher

def __init__(
self,
optimum/exporters/onnx/model_configs.py (20 additions & 4 deletions)
@@ -51,7 +51,6 @@
SentenceTransformersCLIPPatcher,
SentenceTransformersTransformerPatcher,
SpeechT5ModelPatcher,
-    VisionEncoderDecoderPatcher,
VitPoseModelPatcher,
)
from optimum.exporters.tasks import TasksManager
@@ -515,13 +514,32 @@ class Gemma2OnnxConfig(GemmaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")


@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS)
class Gemma3TextOnnxConfig(GemmaOnnxConfig):
# Gemma 3 was added in transformers v4.50 using HybridCache
    # DynamicCache support was added in v4.53
MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")


+# we still don't support gemma3 for multimodal feature-extraction(-with-past) and image-text-to-text(-with-past) tasks
+@register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS, "text-classification")
+@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS, "text-classification")
+class Gemma3OnnxConfig(GemmaOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
+        head_dim="text_config.head_dim",
+        vocab_size="text_config.vocab_size",
+        hidden_size="text_config.hidden_size",
+        num_layers="text_config.num_hidden_layers",
+        num_attention_heads="text_config.num_attention_heads",
+        num_key_value_heads="text_config.num_key_value_heads",
+        allow_new=True,
+    )
+    # Gemma 3 was added in transformers v4.50 using HybridCache
+    # DynamicCache support was added in v4.53
+    MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
+
+    def __init__(self, config: PretrainedConfig, task: str = "feature-extraction", **kwargs):
+        super().__init__(config, task, **kwargs)


@register_tasks_manager_onnx("gpt_oss", *COMMON_TEXT_GENERATION_TASKS)
class GPTOssOnnxConfig(GemmaOnnxConfig):
@@ -2443,9 +2461,7 @@ class TrOCROnnxConfig(TextSeq2SeqOnnxConfig):
)
class VisionEncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
-
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyVisionEncoderDecoderPastKeyValuesGenerator)
-    _MODEL_PATCHER = VisionEncoderDecoderPatcher

@property
def inputs(self) -> dict[str, dict[int, str]]:
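Context for the `Gemma3OnnxConfig` addition above: multimodal gemma3 checkpoints nest all text-decoder settings under `config.text_config`, which is why every attribute in the normalized config is mapped through a dot-separated path (and why `allow_new=True` is passed, presumably because some of these names are not declared on the base normalized config). A minimal sketch of the nested-attribute lookup this relies on; the `resolve` helper and the config values are illustrative, not optimum's actual implementation:

```python
from functools import reduce
from types import SimpleNamespace

def resolve(config, dotted_name: str):
    # Walk a dot-separated attribute path such as "text_config.head_dim".
    return reduce(getattr, dotted_name.split("."), config)

# Illustrative stand-in for a gemma3 multimodal config: text settings are nested.
config = SimpleNamespace(text_config=SimpleNamespace(head_dim=256, num_key_value_heads=4))

assert resolve(config, "text_config.head_dim") == 256
assert resolve(config, "text_config.num_key_value_heads") == 4
```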
optimum/exporters/onnx/model_patcher.py (13 additions & 31 deletions)
@@ -484,13 +484,24 @@ def __init__(

self.real_config = config
self.model_kwargs = model_kwargs if model_kwargs is not None else {}
-        allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past
+        allow_past_in_outputs = getattr(self.real_config, "use_past", False)

@functools.wraps(self.orig_forward)
def patched_forward(*args, **kwargs):
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

+            # Transformers doesn't always respect the config.use_cache attribute:
+            # there are even cases where setting use_cache to True in every config and
+            # sub-config of a model still doesn't enable past_key_values in the outputs (gemma3).
+            # Explicitly setting the use_cache argument of the forward method seems to be the most reliable way.
+            if "use_cache" in signature.parameters:
+                use_cache_index = list(signature.parameters.keys()).index("use_cache")
+                if use_cache_index < len(args):
+                    args[use_cache_index] = allow_past_in_outputs
+                elif "use_cache" in kwargs:
+                    kwargs["use_cache"] = allow_past_in_outputs
+
if is_transformers_version(">=", "4.48"):
if "past_key_values" in signature.parameters:
pkv_index = list(signature.parameters.keys()).index("past_key_values")
@@ -654,13 +665,6 @@ def __init__(

allow_past_in_outputs = getattr(self.real_config, "use_past", False)

-        # sometimes the text_config/decoder is set to False
-        if allow_past_in_outputs:
-            if hasattr(model.config, "text_config"):
-                model.config.text_config.use_cache = True
-            elif hasattr(model.config, "decoder"):
-                model.config.decoder.use_cache = True

# Re-use the patched forward method from the parent class
self.super_patched_forward = self.patched_forward

@@ -717,28 +721,6 @@ def __exit__(self, exc_type, exc_value, traceback):
self._model.set_attention_type("block_sparse")


-class VisionEncoderDecoderPatcher(Seq2SeqModelPatcher):
-    def __init__(
-        self,
-        config: OnnxConfig,
-        model: PreTrainedModel,
-        model_kwargs: dict[str, Any] | None = None,
-    ):
-        super().__init__(config, model, model_kwargs)
-        use_cache = hasattr(self.real_config, "use_past")
-
-        if config._behavior == "decoder" and model.config.decoder.model_type == "trocr" and use_cache:
-            model.decoder.model.decoder.config.use_cache = True
-
-
-class DecoderModelPatcher(ModelPatcher):
-    def __enter__(self):
-        super().__enter__()
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-

class MgpstrModelPatcher(ModelPatcher):
def __init__(
self,
@@ -1291,7 +1273,7 @@ def qwen3_moe_forward_patched(self, hidden_states: torch.Tensor) -> torch.Tensor
return final_hidden_states, router_logits


-class Qwen3MoeModelPatcher(DecoderModelPatcher):
+class Qwen3MoeModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()

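The `use_cache` override added to `ModelPatcher.patched_forward` above has to handle the flag arriving either positionally or as a keyword, since traced `forward` calls may pass it either way. A standalone sketch of that dispatch pattern, with a hypothetical `forward` and a `force_use_cache` helper standing in for the actual patcher:

```python
import functools
import inspect

def forward(input_ids, attention_mask=None, use_cache=None):
    # Hypothetical stand-in for a model's forward method.
    return {"use_cache": use_cache}

def force_use_cache(fn, value: bool):
    signature = inspect.signature(fn)
    params = list(signature.parameters.keys())

    @functools.wraps(fn)
    def patched(*args, **kwargs):
        args = list(args)
        if "use_cache" in params:
            idx = params.index("use_cache")
            if idx < len(args):
                args[idx] = value  # use_cache was passed positionally
            elif "use_cache" in kwargs:
                kwargs["use_cache"] = value  # use_cache was passed as a keyword
        return fn(*args, **kwargs)

    return patched

patched = force_use_cache(forward, True)
assert patched(None, None, False)["use_cache"] is True  # positional override
assert patched(None, use_cache=False)["use_cache"] is True  # keyword override
```

Mirroring the diff, the helper only overrides a `use_cache` that was actually passed; it does not inject the argument when the caller omitted it.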
optimum/onnxruntime/modeling_decoder.py (5 additions & 2 deletions)
@@ -185,8 +185,10 @@ def __init__(
"To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
)

if self.config.model_type in {"gemma", "gemma3", "gemma3_text", "gpt_oss", "nemotron"}:
if self.config.model_type in {"gemma", "gemma3_text", "gpt_oss", "nemotron"}:
self.embed_size_per_head = self.config.head_dim
elif self.config.model_type == "gemma3":
self.embed_size_per_head = self.config.text_config.head_dim
elif self.old_gpt_bigcode_modeling:
# (before v4.54) GPT BigCode fuses keys and values in one tensor, doubling the head dimension
self.embed_size_per_head = self.config.hidden_size // self.config.num_attention_heads * 2
@@ -202,7 +204,6 @@
"deepseek_v3",
"cohere",
"gemma",
"gemma3",
"gemma3_text",
"glm",
"granite",
@@ -218,6 +219,8 @@
"stablelm",
}:
self.num_key_value_heads = self.config.num_key_value_heads
elif self.config.model_type == "gemma3":
self.num_key_value_heads = self.config.text_config.num_key_value_heads
elif self.config.model_type == "falcon":
if self.config.new_decoder_architecture or not self.config.multi_query:
self.num_key_value_heads = self.config.num_kv_heads
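The two gemma3 branches above exist because the multimodal config keeps `head_dim` and `num_key_value_heads` on the nested `text_config`, while `gemma3_text` exposes them at the top level; both values feed the per-layer past-key-values shape that this decoder modeling code allocates. A hedged sketch of that shape computation; the `past_kv_shape` helper and the dimensions are illustrative, not the file's actual code:

```python
from types import SimpleNamespace

def past_kv_shape(config, batch_size: int, past_len: int) -> tuple[int, int, int, int]:
    # Per-layer past key/value shape: (batch, num_kv_heads, past_len, head_dim).
    if config.model_type == "gemma3":  # multimodal: text settings live on text_config
        num_kv_heads = config.text_config.num_key_value_heads
        head_dim = config.text_config.head_dim
    else:  # e.g. "gemma3_text": flat text-only config
        num_kv_heads = config.num_key_value_heads
        head_dim = config.head_dim
    return (batch_size, num_kv_heads, past_len, head_dim)

gemma3 = SimpleNamespace(
    model_type="gemma3",
    text_config=SimpleNamespace(num_key_value_heads=4, head_dim=256),
)
assert past_kv_shape(gemma3, batch_size=1, past_len=8) == (1, 4, 8, 256)
```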
tests/exporters/onnx/test_export_cli.py (3 additions & 5 deletions)
@@ -37,11 +37,7 @@
from optimum.exporters.error_utils import MinimumVersionError
from optimum.exporters.onnx import main_export
from optimum.exporters.tasks import TasksManager
-from optimum.onnxruntime import (
-    ONNX_DECODER_NAME,
-    ONNX_DECODER_WITH_PAST_NAME,
-    ONNX_ENCODER_NAME,
-)
+from optimum.onnxruntime import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME
from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_sentence_transformers, require_timm


@@ -142,13 +138,15 @@ def _get_models_to_test(export_models_dict: dict, library_name: str):
# TODO: encoder-decoder auto-infers text3text-generation, but it uses bert as decoder and does not support past key values
# TODO: vision-encoder-decoder tiny models have wrong labels on the Hub
# TODO: unispeech-sat tiny models have wrong labels on the Hub
+        # TODO: we still don't support gemma3 for image-text-to-text(-with-past) tasks
if model_type not in [
"segformer",
"xlm-roberta",
"perceiver",
"encoder-decoder",
"vision-encoder-decoder",
"unispeech-sat",
"gemma3",
]:
models_to_test.append(
(f"{model_type}_no_task_{model_name}", model_type, model_name, "auto", "default", False, False)