diff --git a/README.md b/README.md index 842c9725a7..2dc72f42dc 100644 --- a/README.md +++ b/README.md @@ -214,6 +214,7 @@ The following model architectures, tasks and device distributions have been vali | Qwen2 |
  • Single card
  • |
  • Single card
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | Qwen2-MoE | |
  • Single card
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | Gemma | :heavy_check_mark: |
  • Single card
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | +| Cohere | |
  • Single card
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | T5 / Flan T5 | :heavy_check_mark: | :heavy_check_mark: |
  • [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)
  • [translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)
  • [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20)
  • | | BART | |
  • Single card
  • |
  • [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)
  • [translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)
  • [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20)
  • | | ViT | :heavy_check_mark: | :heavy_check_mark: |
  • [image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification)
  • | diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 947e41b6fe..90752694b9 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -60,6 +60,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | Qwen2 |
  • Single card
  • |
  • Single card
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | Qwen2-MoE | |
  • Single card
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | Persimmon | |
  • Single card
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | +| Cohere | |
  • Single card
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | T5 / Flan T5 | ✅ | ✅ |
  • [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)
  • [translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)
  • [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20)
  • | | BART | |
  • Single card
  • |
  • [summarization](https://github.com/huggingface/optimum-habana/tree/main/examples/summarization)
  • [translation](https://github.com/huggingface/optimum-habana/tree/main/examples/translation)
  • [question answering](https://github.com/huggingface/optimum-habana/tree/main/examples/question-answering#fine-tuning-t5-on-squad20)
  • | | ViT | ✅ | ✅ |
  • [image classification](https://github.com/huggingface/optimum-habana/tree/main/examples/image-classification)
  • | diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bab7db4359..0e3059aeb4 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -105,6 +105,7 @@ "stablelm", "mamba", "deci", + "cohere", "qwen2_moe", "gemma", "whisper", diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 2fd24148be..91713a68a1 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -40,6 +40,8 @@ GaudiCLIPVisionTransformer, GaudiCodeGenAttention, GaudiCodeGenForCausalLM, + GaudiCohereDecoderLayer, + GaudiCohereForCausalLM, GaudiFalconAttention, GaudiFalconDecoderLayer, GaudiFalconForCausalLM, @@ -150,6 +152,8 @@ gaudi_check_and_enable_sdpa, gaudi_codegen_block_forward, gaudi_codegen_model_forward, + gaudi_cohere_attention_forward, + gaudi_cohere_model_forward, gaudi_conv1d_forward, gaudi_DetrConvModel_forward, gaudi_esm_for_protein_folding_forward, @@ -604,3 +608,9 @@ def adapt_transformers_to_gaudi(): transformers.AutoConfig.register("deci", DeciLMConfig) transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM) + + # Optimization for cohere on Gaudi + transformers.models.cohere.modeling_cohere.CohereDecoderLayer = GaudiCohereDecoderLayer + transformers.models.cohere.modeling_cohere.CohereForCausalLM = GaudiCohereForCausalLM + transformers.models.cohere.modeling_cohere.CohereModel.forward = gaudi_cohere_model_forward + transformers.models.cohere.modeling_cohere.CohereAttention.forward = gaudi_cohere_attention_forward diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 8c9a045efa..d1760881a4 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -45,6 +45,12 @@ gaudi_codegen_block_forward, gaudi_codegen_model_forward, ) +from .cohere import ( + GaudiCohereDecoderLayer, + GaudiCohereForCausalLM, + gaudi_cohere_attention_forward, + gaudi_cohere_model_forward, +) from .decilm import ( DeciLMConfig, DeciLMForCausalLM, diff --git a/optimum/habana/transformers/models/cohere/__init__.py b/optimum/habana/transformers/models/cohere/__init__.py new file mode 100644 index 0000000000..ec3a43831c --- /dev/null +++ b/optimum/habana/transformers/models/cohere/__init__.py @@ -0,0 +1,6 @@ +from .modeling_cohere import ( + GaudiCohereDecoderLayer, + GaudiCohereForCausalLM, + gaudi_cohere_attention_forward, + gaudi_cohere_model_forward, +) diff --git a/optimum/habana/transformers/models/cohere/modeling_cohere.py b/optimum/habana/transformers/models/cohere/modeling_cohere.py new file mode 100644 index 0000000000..c0785c88ed --- /dev/null +++ b/optimum/habana/transformers/models/cohere/modeling_cohere.py @@ -0,0 +1,441 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.cohere.modeling_cohere import ( + Cache, + CohereAttention, + CohereConfig, + CohereDecoderLayer, + CohereForCausalLM, + CohereLayerNorm, + CohereMLP, + DynamicCache, + StaticCache, + apply_rotary_pos_emb, + logger, + repeat_kv, +) + +from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask + + +def gaudi_cohere_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from CohereAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py + The only differences are: + - add new args token_idx + - optimize KV cache + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + if self.use_qk_norm: + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + if token_idx is not None: + if len(past_key_value.key_cache) <= self.layer_idx: + past_key_value.key_cache.append(key_states) + past_key_value.value_cache.append(value_states) + else: + past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) + past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class GaudiCohereDecoderLayer(CohereDecoderLayer): + def __init__(self, config: CohereConfig, layer_idx: int): + super(CohereDecoderLayer, self).__init__() + self.hidden_size = config.hidden_size + + self.self_attn = CohereAttention(config=config, layer_idx=layer_idx) + + self.mlp = CohereMLP(config) + self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from CohereDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py + The only differences are: + - add new args token_idx + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + ) + + # Fully Connected + hidden_states_mlp = self.mlp(hidden_states) + + # Add everything together + hidden_states = residual + hidden_states_attention + hidden_states_mlp + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +def gaudi_cohere_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Copied from CohereModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py + The only differences are: + - add new args token_idx + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.") + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + return_legacy_cache = False + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + token_idx=token_idx, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class GaudiCohereForCausalLM(CohereForCausalLM): + """ + Inherits from CohereForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/modeling_cohere.py + The only differences are: + - add new args token_idx + - add token_idx into model_inputs + - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx + - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx + """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + token_idx=token_idx, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits * self.logit_scale + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + token_idx = kwargs.get("token_idx", None) + if past_key_values is not None: + if token_idx is None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "token_idx": token_idx, + } + ) + return model_inputs diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index a17333cf68..588bb37e8a 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -45,6 +45,7 @@ ("Qwen/Qwen2-7B", 512, False, 9669.45787), ("Qwen/Qwen1.5-MoE-A2.7B", 1, True, 44.25834541569395), ("EleutherAI/gpt-neo-2.7B", 1, False, 257.2476416844122), + ("CohereForAI/c4ai-command-r-v01", 1, False, 29.50315234651154), ], "fp8": [ ("tiiuae/falcon-180B", 4, 950, True, 128, 128, 2506.68),