Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Camembert
- CLIP
- CodeGen
- Cohere
- ConvBert
- ConvNext
- ConvNextV2
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ class Qwen2OnnxConfig(LlamaOnnxConfig):
pass


class CohereOnnxConfig(LlamaOnnxConfig):
    """ONNX export configuration for Cohere models (e.g. Command-R).

    Cohere is treated as Llama-like for export purposes, so every setting is
    inherited from ``LlamaOnnxConfig`` unchanged: same dummy input generators
    and the same past-key-values layout keyed on ``num_key_value_heads``
    (elsewhere in this change set, "cohere" is grouped with "mistral",
    "llama" and "qwen2" in ``prepare_past_key_values`` and is listed in
    ``MODEL_TYPES_REQUIRING_POSITION_IDS``, so exported models take a
    ``position_ids`` input).
    """

    pass


class GemmaOnnxConfig(LlamaOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@

MODEL_TYPES_REQUIRING_POSITION_IDS = {
"codegen",
"cohere",
"falcon",
"gemma",
"gpt2",
Expand Down
8 changes: 8 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,14 @@ class TasksManager:
"text-generation-with-past",
onnx="CodeGenOnnxConfig",
),
"cohere": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
onnx="CohereOnnxConfig",
),
"convbert": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def prepare_past_key_values(
if self.model_type == "gemma":
num_attention_heads = self.normalized_config.num_key_value_heads
embed_size_per_head = self.normalized_config.head_dim
elif self.model_type in {"mistral", "llama", "qwen2"}:
elif self.model_type in {"mistral", "llama", "cohere", "qwen2"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
num_attention_heads = self.normalized_config.num_attention_heads
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class ORTConfigManager:
"bloom": "gpt2",
"camembert": "bert",
"codegen": "gpt2",
"cohere": "gpt2",
"deberta": "bert",
"deberta-v2": "bert",
"distilbert": "bert",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"blenderbot",
"blenderbot-small",
"bloom",
"cohere",
"llama",
"mistral",
"mpt",
Expand Down
3 changes: 2 additions & 1 deletion tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip": "hf-internal-testing/tiny-random-CLIPModel",
"convbert": "hf-internal-testing/tiny-random-ConvBertModel",
"cohere": "hf-internal-testing/tiny-random-CohereModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
"convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",
"codegen": "hf-internal-testing/tiny-random-CodeGenModel",
Expand Down Expand Up @@ -210,6 +210,7 @@
"bloom": "hf-internal-testing/tiny-random-BloomModel", # Not using bigscience/bloom-560m because it goes OOM.
"camembert": "camembert-base",
"clip": "openai/clip-vit-base-patch32",
"cohere": "hf-internal-testing/tiny-random-CohereModel", # Not using CohereForAI/c4ai-command-r-plus because it is gated and/or goes OOM.
"convbert": "YituTech/conv-bert-base",
"convnext": "facebook/convnext-tiny-224",
"codegen": "hf-internal-testing/tiny-random-CodeGenModel", # Not using Salesforce/codegen-350M-multi because it takes too much time for testing.
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2248,6 +2248,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
"bloom",
"codegen",
"cohere",
"falcon",
"gemma",
"gpt2",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip": "hf-internal-testing/tiny-random-CLIPModel",
"cohere": "hf-internal-testing/tiny-random-CohereModel",
"convbert": "hf-internal-testing/tiny-random-ConvBertModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
"convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",
Expand Down