Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- ChineseCLIP
- CLIP
- CodeGen
- Cohere
- ConvBert
- ConvNext
- ConvNextV2
Expand Down Expand Up @@ -58,6 +59,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- OPT
- Granite
- GroupVit
- Helium
- Hiera
- Hubert
- IBert
Expand Down Expand Up @@ -116,6 +118,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- SpeechT5
- Splinter
- SqueezeBert
- StableLM
- Swin
- SwinV2
- T5
Expand Down
18 changes: 18 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from optimum.exporters.onnx.constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from optimum.exporters.onnx.model_patcher import (
CLIPModelPatcher,
CohereModelPatcher,
FluxTransformerModelPatcher,
MgpstrModelPatcher,
MusicgenModelPatcher,
Expand Down Expand Up @@ -440,11 +441,28 @@ class ArceeOnnxConfig(LlamaOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA


@register_tasks_manager_onnx("cohere", *COMMON_TEXT_GENERATION_TASKS)
class CohereOnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
_MODEL_PATCHER = CohereModelPatcher


@register_tasks_manager_onnx("helium", *COMMON_TEXT_GENERATION_TASKS)
class HeliumOnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.49.0")


@register_tasks_manager_onnx("smollm3", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
class SmolLM3OnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")


@register_tasks_manager_onnx("stablelm", *COMMON_TEXT_GENERATION_TASKS)
class StableLMOnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")


@register_tasks_manager_onnx("olmo", *COMMON_TEXT_GENERATION_TASKS)
class OlmoOnnxConfig(LlamaOnnxConfig):
ATOL_FOR_VALIDATION = 1e-4
Expand Down
43 changes: 43 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,3 +1301,46 @@ def __exit__(self, exc_type, exc_value, traceback):

if is_diffusers_version(">=", "0.35.0"):
diffusers.models.transformers.transformer_flux.apply_rotary_emb = self.original_apply_rotary_emb


def patched_cohere_rotary_forward(self, x, position_ids):
# Get batch size and sequence length for manual expansion
batch_size, seq_len = position_ids.shape[:2]

# Instead of using expand, manually repeat the tensor.
# Problem with expand: it creates a view with shared memory rather than copying data,
# which causes ONNX export issues with dynamic shapes and view operations.
# Using repeat() ensures actual memory allocation and data copying for ONNX compatibility.
# original: inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
inv_freq_base = self.inv_freq[None, :, None].float() # Shape: [1, freq_dim, 1]
inv_freq_expanded = inv_freq_base.repeat(batch_size, 1, 1) # Shape: [batch_size, freq_dim, 1]

position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"

with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = freqs.repeat_interleave(2, dim=-1) # diff from Llama: we interleave() instead of cat()
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class CohereModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()

if is_transformers_version(">=", "4.38.0"):
from transformers.models.cohere.modeling_cohere import CohereRotaryEmbedding

self.original_forward = CohereRotaryEmbedding.forward
CohereRotaryEmbedding.forward = patched_cohere_rotary_forward

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)

if is_transformers_version(">=", "4.38.0"):
from transformers.models.cohere.modeling_cohere import CohereRotaryEmbedding

CohereRotaryEmbedding.forward = self.original_forward
3 changes: 3 additions & 0 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,15 @@
MODEL_TYPES_REQUIRING_POSITION_IDS = {
"arcee",
"codegen",
"cohere",
"falcon",
"gemma",
"gpt2",
"gpt_bigcode",
"gpt_neo",
"gpt_neox",
"gptj",
"helium",
"imagegpt",
"internlm2",
"llama",
Expand All @@ -89,6 +91,7 @@
"qwen3_moe",
"granite",
"smollm3",
"stablelm",
"olmo2",
"olmo",
}
Expand Down
3 changes: 3 additions & 0 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,17 @@ def __init__(

if self.config.model_type in {
"arcee",
"cohere",
"gemma",
"helium",
"mistral",
"llama",
"qwen2",
"qwen3",
"qwen3_moe",
"granite",
"smollm3",
"stablelm",
}:
self.num_key_value_heads = self.config.num_key_value_heads

Expand Down
3 changes: 3 additions & 0 deletions tests/exporters/onnx/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"chinese_clip": "hf-internal-testing/tiny-random-ChineseCLIPModel",
"clip": "hf-internal-testing/tiny-random-CLIPModel",
"clip_vision_model": "fxmarty/clip-vision-model-tiny",
"cohere": "hf-internal-testing/tiny-random-CohereForCausalLM",
"colpali": "hf-internal-testing/tiny-random-ColPaliForRetrieval",
"convbert": "hf-internal-testing/tiny-random-ConvBertModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
Expand Down Expand Up @@ -109,6 +110,7 @@
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"helium": "hf-internal-testing/tiny-random-HeliumForCausalLM",
"hiera": "hf-internal-testing/tiny-random-HieraForImageClassification",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
"imagegpt": "hf-internal-testing/tiny-random-ImageGPTModel",
Expand Down Expand Up @@ -174,6 +176,7 @@
"siglip_vision_model": "hf-internal-testing/tiny-random-SiglipVisionModel",
"splinter": "hf-internal-testing/tiny-random-SplinterModel",
"squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel",
"stablelm": "hf-internal-testing/tiny-random-StableLmForCausalLM",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"swinv2": "hf-internal-testing/tiny-random-Swinv2Model",
"swin2sr": "hf-internal-testing/tiny-random-Swin2SRModel",
Expand Down
9 changes: 9 additions & 0 deletions tests/onnxruntime/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
from optimum.exporters.onnx.model_configs import (
ArceeOnnxConfig,
BloomOnnxConfig,
CohereOnnxConfig,
GemmaOnnxConfig,
GraniteOnnxConfig,
HeliumOnnxConfig,
InternLM2OnnxConfig,
MPTOnnxConfig,
Olmo2OnnxConfig,
Expand All @@ -43,6 +45,7 @@
Qwen3MoeOnnxConfig,
Qwen3OnnxConfig,
SmolLM3OnnxConfig,
StableLMOnnxConfig,
)
from optimum.exporters.onnx.utils import MODEL_TYPES_REQUIRING_POSITION_IDS
from optimum.exporters.tasks import TasksManager
Expand Down Expand Up @@ -87,6 +90,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):

if is_transformers_version(">=", str(ArceeOnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("arcee")
if is_transformers_version(">=", str(CohereOnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("cohere")
if is_transformers_version(">=", str(OPTOnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("opt")
if is_transformers_version(">=", str(PhiOnnxConfig.MIN_TRANSFORMERS_VERSION)):
Expand All @@ -105,6 +110,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES.append("mpt")
if is_transformers_version(">=", str(GraniteOnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("granite")
if is_transformers_version(">=", str(HeliumOnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("helium")
if is_transformers_version(">=", str(Phi3OnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("phi3")
if is_transformers_version(">=", str(Qwen3OnnxConfig.MIN_TRANSFORMERS_VERSION)):
Expand All @@ -115,6 +122,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES.append("internlm2")
if is_transformers_version(">=", str(SmolLM3OnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("smollm3")
if is_transformers_version(">=", str(StableLMOnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("stablelm")

# base generation kwargs
TRUST_REMOTE_CODE_MODELS = {"internlm2"} # noqa: RUF012
Expand Down
3 changes: 3 additions & 0 deletions tests/onnxruntime/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip": "hf-internal-testing/tiny-random-CLIPModel",
"cohere": "hf-internal-testing/tiny-random-CohereForCausalLM",
"convbert": "hf-internal-testing/tiny-random-ConvBertModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
"convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",
Expand Down Expand Up @@ -74,6 +75,7 @@
"gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM",
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"helium": "hf-internal-testing/tiny-random-HeliumForCausalLM",
"hiera": "hf-internal-testing/tiny-random-HieraForImageClassification",
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
Expand Down Expand Up @@ -125,6 +127,7 @@
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"stable-diffusion-3": "optimum-internal-testing/tiny-random-stable-diffusion-3",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"stablelm": "hf-internal-testing/tiny-random-StableLmForCausalLM",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"swinv2": "hf-internal-testing/tiny-random-Swinv2Model",
"swin-window": "yujiepan/tiny-random-swin-patch4-window7-224",
Expand Down