diff --git a/README.md b/README.md
index 842c9725a7..2dc72f42dc 100644
--- a/README.md
+++ b/README.md
@@ -214,6 +214,7 @@ The following model architectures, tasks and device distributions have been vali
| Qwen2 | Single card | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Qwen2-MoE | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Gemma | :heavy_check_mark: | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| Cohere | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| T5 / Flan T5 | :heavy_check_mark: | :heavy_check_mark: | [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20) |
| BART | | Single card | [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20) |
| ViT | :heavy_check_mark: | :heavy_check_mark: | [image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification) |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 947e41b6fe..90752694b9 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -60,6 +60,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| Qwen2 | Single card | Single card | [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Qwen2-MoE | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| Persimmon | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
+| Cohere | | Single card | [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation) |
| T5 / Flan T5 | ✅ | ✅ | [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20) |
| BART | | Single card | [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)[translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)[question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20) |
| ViT | ✅ | ✅ | [image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification) |
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index bab7db4359..0e3059aeb4 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -105,6 +105,7 @@
"stablelm",
"mamba",
"deci",
+ "cohere",
"qwen2_moe",
"gemma",
"whisper",
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 2fd24148be..91713a68a1 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -40,6 +40,8 @@
GaudiCLIPVisionTransformer,
GaudiCodeGenAttention,
GaudiCodeGenForCausalLM,
+ GaudiCohereDecoderLayer,
+ GaudiCohereForCausalLM,
GaudiFalconAttention,
GaudiFalconDecoderLayer,
GaudiFalconForCausalLM,
@@ -150,6 +152,8 @@
gaudi_check_and_enable_sdpa,
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
+ gaudi_cohere_attention_forward,
+ gaudi_cohere_model_forward,
gaudi_conv1d_forward,
gaudi_DetrConvModel_forward,
gaudi_esm_for_protein_folding_forward,
@@ -604,3 +608,9 @@ def adapt_transformers_to_gaudi():
transformers.AutoConfig.register("deci", DeciLMConfig)
transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM)
+
+ # Optimization for cohere on Gaudi
+ transformers.models.cohere.modeling_cohere.CohereDecoderLayer = GaudiCohereDecoderLayer
+ transformers.models.cohere.modeling_cohere.CohereForCausalLM = GaudiCohereForCausalLM
+ transformers.models.cohere.modeling_cohere.CohereModel.forward = gaudi_cohere_model_forward
+ transformers.models.cohere.modeling_cohere.CohereAttention.forward = gaudi_cohere_attention_forward
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 8c9a045efa..d1760881a4 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -45,6 +45,12 @@
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
)
+from .cohere import (
+ GaudiCohereDecoderLayer,
+ GaudiCohereForCausalLM,
+ gaudi_cohere_attention_forward,
+ gaudi_cohere_model_forward,
+)
from .decilm import (
DeciLMConfig,
DeciLMForCausalLM,
diff --git a/optimum/habana/transformers/models/cohere/__init__.py b/optimum/habana/transformers/models/cohere/__init__.py
new file mode 100644
index 0000000000..ec3a43831c
--- /dev/null
+++ b/optimum/habana/transformers/models/cohere/__init__.py
@@ -0,0 +1,6 @@
+from .modeling_cohere import (
+ GaudiCohereDecoderLayer,
+ GaudiCohereForCausalLM,
+ gaudi_cohere_attention_forward,
+ gaudi_cohere_model_forward,
+)
diff --git a/optimum/habana/transformers/models/cohere/modeling_cohere.py b/optimum/habana/transformers/models/cohere/modeling_cohere.py
new file mode 100644
index 0000000000..c0785c88ed
--- /dev/null
+++ b/optimum/habana/transformers/models/cohere/modeling_cohere.py
@@ -0,0 +1,441 @@
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.cohere.modeling_cohere import (
+ Cache,
+ CohereAttention,
+ CohereConfig,
+ CohereDecoderLayer,
+ CohereForCausalLM,
+ CohereLayerNorm,
+ CohereMLP,
+ DynamicCache,
+ StaticCache,
+ apply_rotary_pos_emb,
+ logger,
+ repeat_kv,
+)
+
+from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
+
+
+def gaudi_cohere_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    token_idx: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """
+    Copied from CohereAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
+    The only differences are:
+    - add new args token_idx
+    - optimize KV cache
+
+    Projects `hidden_states` into query/key/value heads, optionally applies the
+    q/k norms (when `self.use_qk_norm` is set), applies rotary position
+    embeddings, updates the KV cache, and computes softmax attention followed by
+    the output projection.
+
+    Args:
+        hidden_states: input whose `.size()` unpacks to `(bsz, q_len, _)`.
+        attention_mask: optional additive mask; it is sliced on its last
+            dimension to the key length before being added to the scores.
+        position_ids: passed to `self.rotary_emb` to obtain `cos` / `sin`.
+        past_key_value: optional cache object. Its `key_cache` / `value_cache`
+            lists are accessed directly when `token_idx` is given; otherwise
+            its `update` method is used.
+        output_attentions: when False, the returned attention weights are None.
+        use_cache: accepted for interface compatibility; not read in this function.
+        cache_position: forwarded to `past_key_value.update` through `cache_kwargs`.
+        token_idx: when given, new key/value states are written in place into
+            the existing cache at index `token_idx - 1` along dim 2.
+
+    Returns:
+        `(attn_output, attn_weights, past_key_value)`.
+
+    Raises:
+        ValueError: if the attention output does not have size
+            `(bsz, self.num_heads, q_len, self.head_dim)`.
+    """
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+
+    # Split into heads before the optional q/k norm, which is applied on the per-head view.
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+    if self.use_qk_norm:
+        query_states = self.q_norm(query_states)
+        key_states = self.k_norm(key_states)
+
+    # Move the head dimension in front of the sequence dimension.
+    query_states = query_states.transpose(1, 2)
+    key_states = key_states.transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    cos, sin = self.rotary_emb(value_states, position_ids)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+    if past_key_value is not None:
+        # sin and cos are specific to RoPE models; position_ids needed for the static cache
+        # (cache_kwargs is only consumed by the `past_key_value.update` branch below)
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+        if token_idx is not None:
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                # No entry for this layer yet: store the current states as the layer's cache.
+                # NOTE(review): presumably these first-step tensors already span the full
+                # (padded) sequence so later in-place writes fit - confirm against the caller.
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                # Overwrite the slot at `token_idx - 1` in place instead of growing the cache.
+                past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states)
+                past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states)
+                # Attend over the whole cached tensors for this layer.
+                key_states = past_key_value.key_cache[self.layer_idx]
+                value_states = past_key_value.value_cache[self.layer_idx]
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    # Expand key/value heads to match the number of query heads.
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+    if attention_mask is not None:  # no matter the length, we just slice it
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+
+    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+        raise ValueError(
+            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+            f" {attn_output.size()}"
+        )
+
+    # Back to (bsz, q_len, heads, head_dim) and merge heads before the output projection.
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+class GaudiCohereDecoderLayer(CohereDecoderLayer):
+    """Cohere decoder layer whose forward additionally accepts and forwards `token_idx`."""
+
+    def __init__(self, config: CohereConfig, layer_idx: int):
+        # Start the MRO lookup *after* CohereDecoderLayer so that its own
+        # __init__ is not run; the submodules are created explicitly below.
+        super(CohereDecoderLayer, self).__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = CohereAttention(config=config, layer_idx=layer_idx)
+        self.mlp = CohereMLP(config)
+        self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        token_idx: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Copied from CohereDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
+        The only differences are:
+        - add new args token_idx
+
+        The attention and MLP branches both read the same layer-normed input and
+        are summed with the un-normed input.
+
+        Returns:
+            A tuple starting with the new hidden states, followed by the
+            attention weights if `output_attentions` and then the cache
+            returned by the attention module if `use_cache`.
+        """
+        normed_input = self.input_layernorm(hidden_states)
+
+        # Attention branch
+        attn_branch, attn_probs, updated_cache = self.self_attn(
+            hidden_states=normed_input,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            token_idx=token_idx,
+        )
+
+        # MLP branch, fed from the same normed input
+        mlp_branch = self.mlp(normed_input)
+
+        # Residual sum, kept in the order (input + attention) + mlp
+        combined = hidden_states + attn_branch + mlp_branch
+
+        optional_outputs = []
+        if output_attentions:
+            optional_outputs.append(attn_probs)
+        if use_cache:
+            optional_outputs.append(updated_cache)
+
+        return (combined, *optional_outputs)
+
+
+def gaudi_cohere_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    token_idx: Optional[torch.Tensor] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    """
+    Copied from CohereModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
+    The only differences are:
+    - add new args token_idx
+
+    Embeds the inputs (unless `inputs_embeds` is given), prepares the cache,
+    positions and causal mask, runs every decoder layer and applies the final
+    norm. `token_idx` is forwarded to each decoder layer on the
+    non-checkpointed path.
+
+    Flags left as None (`output_attentions`, `output_hidden_states`,
+    `use_cache`, `return_dict`) fall back to the values in `self.config`.
+
+    Returns:
+        A `BaseModelOutputWithPast` when `return_dict` is true, otherwise a
+        tuple of the non-None values among (last hidden state, cache, all
+        hidden states, all attentions).
+
+    Raises:
+        ValueError: unless exactly one of `input_ids` / `inputs_embeds` is provided.
+    """
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # XOR: true when both or neither of input_ids / inputs_embeds are supplied.
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError(
+            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+        )
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    past_seen_tokens = 0
+    return_legacy_cache = False
+    # A None `past_key_values` also takes this branch, since None is not a `Cache`.
+    if (
+        use_cache and not isinstance(past_key_values, Cache) and not self.training
+    ):  # kept for BC (non `Cache` `past_key_values` inputs)
+        return_legacy_cache = True
+        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+    if cache_position is None:
+        # Default positions continue right after whatever the cache already holds.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            # Positional call; `token_idx` is not forwarded on this path, so the
+            # layer's own default for it applies.
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                causal_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+                cache_position,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                token_idx=token_idx,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            # The cache sits after the attention weights when those are returned too.
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if return_legacy_cache:
+        # Convert back so callers that passed a legacy cache get one in return.
+        next_cache = next_cache.to_legacy_cache()
+
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
+class GaudiCohereForCausalLM(CohereForCausalLM):
+    """
+    Inherits from CohereForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py
+    The only differences are:
+    - add new args token_idx
+    - add token_idx into model_inputs
+    - from step 2 on, when the KV cache is enabled, slice next_input_ids from input_ids based on token_idx
+    - from step 2 on, when the KV cache is enabled, slice next_position_ids from position_ids based on token_idx
+    """
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        token_idx: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        """
+        Run the decoder, project to scaled float logits and optionally compute the LM loss.
+
+        All arguments except `labels` are forwarded to `self.model`, including
+        `token_idx`. When `labels` is given, logits and labels are shifted by
+        one position and a cross-entropy loss is computed.
+
+        Returns:
+            A `CausalLMOutputWithPast` when `return_dict` is true, otherwise a
+            tuple `(logits, *rest)` prefixed with the loss when it was computed.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            token_idx=token_idx,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits * self.logit_scale
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        **kwargs,
+    ):
+        """
+        Build the keyword arguments for one generation step.
+
+        `token_idx` is read from `kwargs`. When a cache is present, `input_ids`
+        (and `position_ids`) are narrowed to the current step: by selecting
+        column `token_idx - 1` if `token_idx` is given, otherwise by slicing
+        with `cache_position`.
+
+        Returns:
+            A dict with `input_ids` / `inputs_embeds`, `position_ids`,
+            `cache_position`, `past_key_values`, `use_cache`, `attention_mask`
+            and `token_idx`.
+        """
+        token_idx = kwargs.get("token_idx", None)
+        if past_key_values is not None:
+            if token_idx is None:
+                if inputs_embeds is not None:  # Exception 1
+                    input_ids = input_ids[:, -cache_position.shape[0] :]
+                elif (
+                    input_ids.shape[1] != cache_position.shape[0]
+                ):  # Default case (the "else", a no op, is Exception 2)
+                    input_ids = input_ids[:, cache_position]
+            else:
+                # Keep the full-length input and pick only the current token.
+                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                if token_idx is not None:
+                    position_ids = torch.index_select(position_ids, 1, token_idx - 1)
+                else:
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
+
+                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+                position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+        else:
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+        # Guard against `attention_mask=None`: reading `.ndim` on None would raise
+        # AttributeError. Without a mask there is nothing to expand, so None is passed on.
+        if isinstance(past_key_values, StaticCache) and attention_mask is not None and attention_mask.ndim == 2:
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
+            else:
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
+
+            dtype = self.lm_head.weight.dtype
+            min_dtype = torch.finfo(dtype).min
+
+            # NOTE(review): these keyword arguments mirror upstream's
+            # `_prepare_4d_causal_attention_mask_with_cache_position`; confirm that
+            # `_gaudi_prepare_4d_causal_attention_mask` accepts the same signature.
+            attention_mask = _gaudi_prepare_4d_causal_attention_mask(
+                attention_mask,
+                sequence_length=sequence_length,
+                target_length=past_key_values.get_max_length(),
+                dtype=dtype,
+                device=device,
+                min_dtype=min_dtype,
+                cache_position=cache_position,
+                batch_size=batch_size,
+            )
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "token_idx": token_idx,
+            }
+        )
+        return model_inputs
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index a17333cf68..588bb37e8a 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -45,6 +45,7 @@
("Qwen/Qwen2-7B", 512, False, 9669.45787),
("Qwen/Qwen1.5-MoE-A2.7B", 1, True, 44.25834541569395),
("EleutherAI/gpt-neo-2.7B", 1, False, 257.2476416844122),
+ ("CohereForAI/c4ai-command-r-v01", 1, False, 29.50315234651154),
],
"fp8": [
("tiiuae/falcon-180B", 4, 950, True, 128, 128, 2506.68),