From 3b1a66032048cb0479c3b0cbf90f0ea97211c46d Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 15 Oct 2024 00:20:28 -0700 Subject: [PATCH 01/10] add mllama support Signed-off-by: Wang, Yi A --- examples/image-to-text/run_pipeline.py | 22 +- .../habana/transformers/generation/utils.py | 7 + optimum/habana/transformers/modeling_utils.py | 16 + .../habana/transformers/models/__init__.py | 9 + .../transformers/models/mllama/__init__.py | 9 + .../models/mllama/modeling_mllama.py | 850 ++++++++++++++++++ 6 files changed, 906 insertions(+), 7 deletions(-) create mode 100644 optimum/habana/transformers/models/mllama/__init__.py create mode 100644 optimum/habana/transformers/models/mllama/modeling_mllama.py diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index d80939b43f..7421bd3a90 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -23,7 +23,7 @@ import PIL.Image import requests import torch -from transformers import AutoConfig, LlavaNextProcessor, LlavaProcessor, pipeline +from transformers import AutoConfig, AutoProcessor, pipeline from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi @@ -155,17 +155,14 @@ def main(): adapt_transformers_to_gaudi() model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type - if args.image_path is None and model_type == "llava": + if args.image_path is None and model_type in ["llava", "mllama"]: args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"] elif args.image_path is None and model_type == "llava_next": args.image_path = [ "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" ] - if args.prompt is None and model_type in ("llava", "llava_next"): - if model_type == "llava": - processor = LlavaProcessor.from_pretrained(args.model_name_or_path) - elif model_type == "llava_next": - processor = LlavaNextProcessor.from_pretrained(args.model_name_or_path) + if args.prompt is None and model_type in ["llava", "llava_next", "mllama"]: + processor = AutoProcessor.from_pretrained(args.model_name_or_path) conversation = [ { "role": "user", @@ -228,6 +225,17 @@ def main(): if args.quant_config: generator.model = setup_quantization(generator.model, args) + # delete once pipeline integrate AutoProcessor as preprocess engine + if model_type in ["mllama"]: + from transformers.image_utils import load_image + + def preprocess(self, image, prompt=None, timeout=None): + image = load_image(image, timeout=timeout) + model_inputs = processor(images=image, text=prompt, return_tensors=self.framework) + return model_inputs + + generator.__class__.preprocess = preprocess + # warm up for i in range(args.warmup): generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bab7db4359..d8a29cb807 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -108,6 +108,7 @@ "qwen2_moe", "gemma", "whisper", + "mllama", ] @@ -1103,6 +1104,12 @@ def generate( (0, generation_config.max_new_tokens), value=0, ) + if model_kwargs.get("cross_attention_mask") is not None: + model_kwargs["cross_attention_mask"] = torch.nn.functional.pad( + model_kwargs["cross_attention_mask"], + (0, 0, 0, 0, 0, generation_config.max_new_tokens), + value=0, + ) else: assert generation_config.bucket_size <= 0, "Untested path for bucket>0" token_idx = 1 diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 2fd24148be..f8224365a5 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -82,6 +82,13 @@ GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, GaudiMixtralModel, + GaudiMllamaCrossAttentionDecoderLayer, + GaudiMllamaForCausalLM, + GaudiMllamaForConditionalGeneration, + GaudiMllamaSelfAttentionDecoderLayer, + GaudiMllamaTextCrossAttention, + GaudiMllamaTextModel, + GaudiMllamaTextSelfAttention, GaudiMptAttention, GaudiMptBlock, GaudiMptForCausalLM, @@ -602,5 +609,14 @@ def adapt_transformers_to_gaudi(): transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration = GaudiWhisperForConditionalGeneration transformers.models.whisper.modeling_whisper.WHISPER_ATTENTION_CLASSES = GAUDI_WHISPER_ATTENTION_CLASSES + # Optimization for mllama on Gaudi + transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer = GaudiMllamaSelfAttentionDecoderLayer + transformers.models.mllama.modeling_mllama.MllamaCrossAttentionDecoderLayer = GaudiMllamaCrossAttentionDecoderLayer + transformers.models.mllama.modeling_mllama.MllamaForCausalLM = GaudiMllamaForCausalLM + transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention = GaudiMllamaTextSelfAttention + transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention = GaudiMllamaTextCrossAttention + transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = GaudiMllamaForConditionalGeneration + transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel + transformers.AutoConfig.register("deci", DeciLMConfig) transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM) diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 8c9a045efa..33b349d3d3 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -138,6 +138,15 @@ gaudi_mixtral_block_sparse_moe_forward, gaudi_mixtral_rmsnorm_forward, ) +from .mllama import ( + GaudiMllamaCrossAttentionDecoderLayer, + GaudiMllamaForCausalLM, + GaudiMllamaForConditionalGeneration, + GaudiMllamaSelfAttentionDecoderLayer, + GaudiMllamaTextCrossAttention, + GaudiMllamaTextModel, + GaudiMllamaTextSelfAttention, +) from .modeling_all_models import ( gaudi_check_and_enable_sdpa, gaudi_conv1d_forward, diff --git a/optimum/habana/transformers/models/mllama/__init__.py b/optimum/habana/transformers/models/mllama/__init__.py new file mode 100644 index 0000000000..03108cbdd4 --- /dev/null +++ b/optimum/habana/transformers/models/mllama/__init__.py @@ -0,0 +1,9 @@ +from .modeling_mllama import ( + GaudiMllamaCrossAttentionDecoderLayer, + GaudiMllamaForCausalLM, + GaudiMllamaForConditionalGeneration, + GaudiMllamaSelfAttentionDecoderLayer, + GaudiMllamaTextCrossAttention, + GaudiMllamaTextModel, + GaudiMllamaTextSelfAttention, +) diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py new file mode 100644 index 0000000000..f9b07f00ae --- /dev/null +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -0,0 +1,850 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig +from transformers.models.mllama.modeling_mllama import ( + MllamaCrossAttentionDecoderLayer, + MllamaForCausalLM, + MllamaForConditionalGeneration, + MllamaSelfAttentionDecoderLayer, + MllamaTextCrossAttention, + MllamaTextModel, + MllamaTextSelfAttention, + _prepare_4d_causal_attention_mask_with_cache_position, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import ( + logging, +) + + +logger = logging.get_logger(__name__) + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Copied from _prepare_cross_attention_mask: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L99 + The only differences are: + - if there's pading in cross_attention_mask in the right. do not masked it, or else it will impact softmax in crossattention + """ + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + pad_idx = torch.nonzero(cross_attention_mask)[-1][1] + cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + full_text_row_masked_out_mask2 = full_text_row_masked_out_mask.clone() + full_text_row_masked_out_mask2[:, :, pad_idx + 1 :, :] = 1 + cross_attention_mask *= full_text_row_masked_out_mask2 + + return cross_attention_mask, full_text_row_masked_out_mask + + +class GaudiMllamaTextCrossAttention(MllamaTextCrossAttention): + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from MllamaTextCrossAttention::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L512 + The only differences are: + - add token_idx support + - add support if past_key_value is not Cache + """ + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + if past_key_value is not None: + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! + if isinstance(past_key_value, Cache): + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + else: + if token_idx is not None: + past_key_value[0].index_copy_(2, token_idx - 1, key_states) + past_key_value[1].index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value[0] + value_states = past_key_value[1] + else: + key_states = torch.cat((past_key_value[0], key_states), dim=2) + value_states = torch.cat((past_key_value[1], value_states), dim=2) + if use_cache and not isinstance(past_key_value, Cache): + past_key_value = (key_states, value_states) + elif cache_position[0] != 0: + if isinstance(past_key_value, Cache): + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + key_states, value_states = (past_key_value[0], past_key_value[1]) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class GaudiMllamaTextSelfAttention(MllamaTextSelfAttention): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + token_idx: Optional[torch.Tensor] = None, + **kwargs, + ): + """ + Copied from MllamaTextSelfAttention::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L733 + The only differences are: + - add token_idx support + - add support if past_key_value is not Cache + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + if isinstance(past_key_value, Cache): + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + else: + if token_idx is not None: + past_key_value[0].index_copy_(2, token_idx - 1, key_states) + past_key_value[1].index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value[0] + value_states = past_key_value[1] + else: + key_states = torch.cat((past_key_value[0], key_states), dim=2) + value_states = torch.cat((past_key_value[1], value_states), dim=2) + if use_cache and not isinstance(past_key_value, Cache): + past_key_value = (key_states, value_states) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer +class GaudiMllamaSelfAttentionDecoderLayer(MllamaSelfAttentionDecoderLayer): + def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: + super(GaudiMllamaSelfAttentionDecoderLayer, self).__init__(config, layer_idx) + self.self_attn = GaudiMllamaTextSelfAttention(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + token_idx: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from MllamaSelfAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L904 + The only differences are: + - add token_idx input + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class GaudiMllamaCrossAttentionDecoderLayer(MllamaCrossAttentionDecoderLayer): + def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: + super(GaudiMllamaCrossAttentionDecoderLayer, self).__init__(config, layer_idx) + self.cross_attn = GaudiMllamaTextCrossAttention(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + attention_mask: torch.Tensor, + full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + token_idx: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor]: + """ + Copied from MllamaCrossAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L989 + The only differences are: + - add token_idx support + - pass use_cache to cross_attn + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, attn_weights, past_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + use_cache=use_cache, + token_idx=token_idx, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if full_text_row_masked_out_mask is not None: + hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + +class GaudiMllamaTextModel(MllamaTextModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.FloatTensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Copied from MllamaTextModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1617 + The only differences are: + - add token_idx support + - add support if past_key_value is not Cache + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + if cache_position is None: + if isinstance(past_key_values, Cache): + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + else: + past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if token_idx is not None and past_key_values is not None: + past_seen_tokens = token_idx + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None if isinstance(past_key_values, Cache) else () + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # For text-only path we should skip cross attention layers. + # Let's check if the layer is cross attention layer and if we have cross attention states + # or cached cross attention states. + is_cross_attention_layer = idx in self.cross_attention_layers + is_cross_attention_cache_empty = past_key_values is None or ( + past_key_values is not None and past_key_values.get_seq_length(idx) == 0 + if isinstance(past_key_values, Cache) + else False + ) + + if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: + continue + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + cross_attention_states, + cross_attention_mask, + causal_mask, + full_text_row_masked_out_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + if isinstance(past_key_values, Cache): + past_key_value = past_key_values + else: + past_key_value = None if past_key_values is None else past_key_values[idx] + layer_outputs = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + attention_mask=causal_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + if isinstance(past_key_values, Cache): + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + else: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + """ + Copied from MllamaTextModel::_update_causal_mask: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1768 + The only differences are: + - add support if past_key_value is not Cache + """ + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + if isinstance(past_key_values, Cache): + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + else: + past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line + # self.config._attn_implementation == "sdpa" and + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class GaudiMllamaForCausalLM(MllamaForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + token_idx: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Copied from MllamaForCausalLM::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1871 + The only differences are: + - add token_idx input + - add logits handle if token_idx is not None + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + token_idx=token_idx, + ) + + hidden_states = outputs[0] + + if token_idx is None and num_logits_to_keep != 0: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + else: + logits = self.lm_head(hidden_states).float() + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class GaudiMllamaForConditionalGeneration(MllamaForConditionalGeneration): + def __init__(self, config: MllamaConfig): + config._attn_implementation = "eager" + super(GaudiMllamaForConditionalGeneration, self).__init__(config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + token_idx: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077 + The only differences are: + - add token_idx input + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + token_idx=token_idx, + ) + + return outputs + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + pixel_values=None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, + past_key_values=None, + use_cache=False, + cache_position=None, + num_logits_to_keep=None, + **kwargs, + ): + """ + Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208 + The only differences are: + - add token_idx handling + """ + token_idx = kwargs.get("token_idx", None) + if past_key_values is not None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + if token_idx is not None and past_key_values: + cache_position = torch.arange(token_idx - 1, token_idx - 1 + input_ids.shape[1], device=input_ids.device) + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "cross_attention_mask": cross_attention_mask, + "token_idx": token_idx, + } + ) + + # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios + # to compute image hidden states, otherwise they are cached within each cross attn layer + if (input_ids == self.config.image_token_index).any(): + model_inputs["pixel_values"] = pixel_values + model_inputs["aspect_ratio_ids"] = aspect_ratio_ids + model_inputs["aspect_ratio_mask"] = aspect_ratio_mask + + return model_inputs From 866ff92d54d478e9208134b58338c20fd9f1076e Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 15 Oct 2024 01:49:27 -0700 Subject: [PATCH 02/10] add cross attention mask handling Signed-off-by: Wang, Yi A --- .../models/mllama/modeling_mllama.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py index f9b07f00ae..878ef1ac64 100644 --- a/optimum/habana/transformers/models/mllama/modeling_mllama.py +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -788,7 +788,7 @@ def prepare_inputs_for_generation( **kwargs, ): """ - Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208 + Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208 The only differences are: - add token_idx handling """ @@ -848,3 +848,30 @@ def prepare_inputs_for_generation( model_inputs["aspect_ratio_mask"] = aspect_ratio_mask return model_inputs + + def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): + """ + Copied from MllamaForConditionalGeneration::_update_model_kwargs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2274 + The only differences are: + - add token_idx handling + """ + cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + # add cross-attn mask for new token + if cross_attention_mask_prev is not None: + token_idx = model_kwargs.get("token_idx", None) + if token_idx is None: + model_kwargs["cross_attention_mask"] = torch.cat( + [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 + ) + else: + model_kwargs["cross_attention_mask"][:, token_idx - 1, ...] = cross_attention_mask_prev[ + :, token_idx - 2, ... + ].clone() + return model_kwargs From 1981f041860925609de9004cb5dc6b5fb9b9ee99 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 15 Oct 2024 19:29:45 -0700 Subject: [PATCH 03/10] remove unused code Signed-off-by: Wang, Yi A --- optimum/habana/transformers/models/mllama/modeling_mllama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py index 878ef1ac64..d5cfd77010 100644 --- a/optimum/habana/transformers/models/mllama/modeling_mllama.py +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -412,8 +412,6 @@ def forward( past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 else: past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if token_idx is not None and past_key_values is not None: - past_seen_tokens = token_idx cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) From afeedc4a4886f7858551f8a1da6b5fcd861ec805 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Wed, 16 Oct 2024 03:10:07 -0700 Subject: [PATCH 04/10] add bucket internal for mllama Signed-off-by: Wang, Yi A --- .../habana/transformers/generation/utils.py | 11 +++++---- .../models/mllama/modeling_mllama.py | 23 ++++++++++++++----- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d8a29cb807..9e01608024 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -324,11 +324,13 @@ def _expand_dict_for_generation(dict_to_expand): def _pad_past_key_values(self, model_kwargs): pad_amount = model_kwargs.get("kv_cache_pad_len", 0) + kv_cache_len = model_kwargs.get("kv_cache_len", 0) if model_kwargs["past_key_values"]: if model_kwargs.get("mqa_model", False): for i in range(len(model_kwargs["past_key_values"])): # layer - if torch.is_tensor( - model_kwargs["past_key_values"][i] + if ( + torch.is_tensor(model_kwargs["past_key_values"][i]) + and model_kwargs["past_key_values"][i].shape[-2] == kv_cache_len - pad_amount ): # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked model_kwargs["past_key_values"][i] = torch.nn.functional.pad( model_kwargs["past_key_values"][i], (0, 0, 0, pad_amount) @@ -338,8 +340,9 @@ def _pad_past_key_values(self, model_kwargs): else: for i in range(len(model_kwargs["past_key_values"])): # layer for j in range(len(model_kwargs["past_key_values"][i])): # k or v - if torch.is_tensor( - model_kwargs["past_key_values"][i][j] + if ( + torch.is_tensor(model_kwargs["past_key_values"][i][j]) + and model_kwargs["past_key_values"][i][j].shape[-2] == kv_cache_len - pad_amount ): # tensor(batch_size, n_heads, kv_cache_len, head_dim) model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad( model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount) diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py index d5cfd77010..df4e1bf40e 100644 --- a/optimum/habana/transformers/models/mllama/modeling_mllama.py +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -49,6 +49,7 @@ def _prepare_cross_attention_mask( cross_attention_mask: torch.Tensor, num_vision_tokens: int, dtype: str, + token_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Copied from _prepare_cross_attention_mask: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L99 @@ -57,7 +58,6 @@ def _prepare_cross_attention_mask( """ # reshape so it can be used by attn module batch_size, text_total_length, *_ = cross_attention_mask.shape - pad_idx = torch.nonzero(cross_attention_mask)[-1][1] cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) cross_attention_mask = cross_attention_mask.unsqueeze(1) @@ -74,9 +74,12 @@ def _prepare_cross_attention_mask( full_text_row_masked_out_mask = ( (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] ) - full_text_row_masked_out_mask2 = full_text_row_masked_out_mask.clone() - full_text_row_masked_out_mask2[:, :, pad_idx + 1 :, :] = 1 - cross_attention_mask *= full_text_row_masked_out_mask2 + if token_idx is not None: + full_text_row_masked_out_mask2 = full_text_row_masked_out_mask.clone() + full_text_row_masked_out_mask2[:, :, token_idx:, :] = 1 + cross_attention_mask *= full_text_row_masked_out_mask2 + else: + cross_attention_mask *= full_text_row_masked_out_mask return cross_attention_mask, full_text_row_masked_out_mask @@ -131,7 +134,7 @@ def forward( key_states = torch.cat((past_key_value[0], key_states), dim=2) value_states = torch.cat((past_key_value[1], value_states), dim=2) if use_cache and not isinstance(past_key_value, Cache): - past_key_value = (key_states, value_states) + past_key_value = [key_states, value_states] elif cache_position[0] != 0: if isinstance(past_key_value, Cache): key_states, value_states = ( @@ -213,7 +216,7 @@ def forward( key_states = torch.cat((past_key_value[0], key_states), dim=2) value_states = torch.cat((past_key_value[1], value_states), dim=2) if use_cache and not isinstance(past_key_value, Cache): - past_key_value = (key_states, value_states) + past_key_value = [key_states, value_states] key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -740,6 +743,7 @@ def forward( cross_attention_mask, num_vision_tokens=self.vision_model.num_patches, dtype=self.dtype, + token_idx=token_idx, ) else: full_text_row_masked_out_mask = None @@ -789,8 +793,10 @@ def prepare_inputs_for_generation( Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208 The only differences are: - add token_idx handling + - add bucket_internal handling """ token_idx = kwargs.get("token_idx", None) + bucket_internal = kwargs.get("bucket_internal", None) if past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) @@ -798,6 +804,11 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] + elif bucket_internal and token_idx is not None: + # for the 1st token we can slice the inputs till token idx for the fwd pass. + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] + cache_position = cache_position[:token_idx] # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way if attention_mask is not None and position_ids is None: From ca62f24c8f720a5fe710029a6870529b78390bad Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Fri, 18 Oct 2024 07:29:13 -0700 Subject: [PATCH 05/10] optimize vision model and set cache_position to None to fix hpu graph issue Signed-off-by: Wang, Yi A --- optimum/habana/transformers/modeling_utils.py | 2 + .../habana/transformers/models/__init__.py | 1 + .../transformers/models/mllama/__init__.py | 1 + .../models/mllama/modeling_mllama.py | 248 +++++++++++++++--- 4 files changed, 214 insertions(+), 38 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index f8224365a5..9be3558cae 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -89,6 +89,7 @@ GaudiMllamaTextCrossAttention, GaudiMllamaTextModel, GaudiMllamaTextSelfAttention, + GaudiMllamaVisionModel, GaudiMptAttention, GaudiMptBlock, GaudiMptForCausalLM, @@ -617,6 +618,7 @@ def adapt_transformers_to_gaudi(): transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention = GaudiMllamaTextCrossAttention transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = GaudiMllamaForConditionalGeneration transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel + transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel transformers.AutoConfig.register("deci", DeciLMConfig) transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM) diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 33b349d3d3..16b9f766b3 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -146,6 +146,7 @@ GaudiMllamaTextCrossAttention, GaudiMllamaTextModel, GaudiMllamaTextSelfAttention, + GaudiMllamaVisionModel, ) from .modeling_all_models import ( gaudi_check_and_enable_sdpa, diff --git a/optimum/habana/transformers/models/mllama/__init__.py b/optimum/habana/transformers/models/mllama/__init__.py index 03108cbdd4..198f1cc2aa 100644 --- a/optimum/habana/transformers/models/mllama/__init__.py +++ b/optimum/habana/transformers/models/mllama/__init__.py @@ -6,4 +6,5 @@ GaudiMllamaTextCrossAttention, GaudiMllamaTextModel, GaudiMllamaTextSelfAttention, + GaudiMllamaVisionModel, ) diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py index df4e1bf40e..28a61d5901 100644 --- a/optimum/habana/transformers/models/mllama/modeling_mllama.py +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -18,12 +18,13 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig from transformers.models.mllama.modeling_mllama import ( MllamaCrossAttentionDecoderLayer, @@ -33,7 +34,9 @@ MllamaTextCrossAttention, MllamaTextModel, MllamaTextSelfAttention, + MllamaVisionModel, _prepare_4d_causal_attention_mask_with_cache_position, + _prepare_aspect_ratio_attention_mask, apply_rotary_pos_emb, repeat_kv, ) @@ -41,6 +44,10 @@ logging, ) +from ...modeling_attn_mask_utils import ( + _gaudi_prepare_4d_causal_attention_mask, +) + logger = logging.get_logger(__name__) @@ -101,6 +108,7 @@ def forward( The only differences are: - add token_idx support - add support if past_key_value is not Cache + - cache position is None """ """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() @@ -135,14 +143,13 @@ def forward( value_states = torch.cat((past_key_value[1], value_states), dim=2) if use_cache and not isinstance(past_key_value, Cache): past_key_value = [key_states, value_states] - elif cache_position[0] != 0: - if isinstance(past_key_value, Cache): - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - key_states, value_states = (past_key_value[0], past_key_value[1]) + elif not isinstance(past_key_value, Cache) and past_key_value is not None: + key_states, value_states = (past_key_value[0], past_key_value[1]) + elif cache_position is not None and cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) else: raise ValueError( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" @@ -409,21 +416,37 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - - if cache_position is None: - if isinstance(past_key_values, Cache): - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - else: - past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + if isinstance(past_key_values, Cache): + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + else: + past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ignore_cache_position = True # Ignoring cache position for HPU, or else hpu graph may has issue + if ignore_cache_position is False: + if cache_position is None: + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + else: + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, + inputs_embeds.shape[1] + past_seen_tokens, + dtype=torch.long, + device=inputs_embeds.device, + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape, + inputs_embeds, + past_seen_tokens, ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -673,7 +696,8 @@ def forward( class GaudiMllamaForConditionalGeneration(MllamaForConditionalGeneration): def __init__(self, config: MllamaConfig): - config._attn_implementation = "eager" + # sdpa is better for vision model in HPU + config._attn_implementation = "sdpa" super(GaudiMllamaForConditionalGeneration, self).__init__(config) def forward( @@ -696,6 +720,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077 @@ -748,9 +773,13 @@ def forward( else: full_text_row_masked_out_mask = None - if cross_attention_mask is not None and cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + if cross_attention_mask is not None: + if cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + elif token_idx is not None and past_key_values is not None: + cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1) + full_text_row_masked_out_mask = torch.index_select(full_text_row_masked_out_mask, -2, token_idx - 1) outputs = self.language_model( input_ids=input_ids, @@ -808,7 +837,8 @@ def prepare_inputs_for_generation( # for the 1st token we can slice the inputs till token idx for the fwd pass. input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] - cache_position = cache_position[:token_idx] + if cross_attention_mask is not None: + cross_attention_mask = cross_attention_mask[:, :token_idx, ...] # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way if attention_mask is not None and position_ids is None: @@ -834,8 +864,8 @@ def prepare_inputs_for_generation( if num_logits_to_keep is not None: model_inputs["num_logits_to_keep"] = num_logits_to_keep - if token_idx is not None and past_key_values: - cache_position = torch.arange(token_idx - 1, token_idx - 1 + input_ids.shape[1], device=input_ids.device) + # keep cache_position implementation as None for HPU + cache_position = None model_inputs.update( { @@ -875,12 +905,154 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_ # add cross-attn mask for new token if cross_attention_mask_prev is not None: token_idx = model_kwargs.get("token_idx", None) - if token_idx is None: - model_kwargs["cross_attention_mask"] = torch.cat( - [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 - ) - else: - model_kwargs["cross_attention_mask"][:, token_idx - 1, ...] = cross_attention_mask_prev[ - :, token_idx - 2, ... - ].clone() + if token_idx is not None: + mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx - 1, ...] + cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask) + model_kwargs["cross_attention_mask"] = cross_attention_mask_prev return model_kwargs + + +class GaudiMllamaVisionModel(MllamaVisionModel): + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: + """ + Copied from MllamaVisionModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1425 + The only differences are: + - optimize perf of stage "Collect intermediate layer outputs from encoder output" + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape + + pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) + + # Patch embedding + patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # Tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids) + + # Add cls token + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = (0, 0, 0, num_padding_patches) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + # Prepare attention mask + attention_mask = aspect_ratio_mask.reshape(batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + # Apply encoder + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + ) + hidden_state = output[0] + + hidden_state = self.layernorm_post(hidden_state) + + # Apply global encoder + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles * (num_patches + num_padding_patches), dim + ) + global_output = self.global_transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + hidden_state = global_output[0] + + # Remove padding form hidden state + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) + + # Collect intermediate layer outputs from encoder output + all_intermediate_hidden_states = output[1] + intermediate_hidden_states = [ + hidden_state + for idx, hidden_state in enumerate(all_intermediate_hidden_states) + if idx in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + + """ + intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) + intermediate_hidden_states = intermediate_hidden_states[..., self.intermediate_layers_indices] + """ + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1 + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + + if output_hidden_states: + hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1]) + else: + hidden_states = None + + if output_attentions: + # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range + global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) + attentions = tuple(output[2]) + global_attn + else: + attentions = None + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) From 84e83ddf3ea848693b050ae1117ba014a6f7c1fe Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sat, 19 Oct 2024 06:44:56 -0700 Subject: [PATCH 06/10] refine crossattention logic Signed-off-by: Wang, Yi A --- .../models/mllama/modeling_mllama.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py index 28a61d5901..7f09fccac8 100644 --- a/optimum/habana/transformers/models/mllama/modeling_mllama.py +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -777,10 +777,15 @@ def forward( if cache_position is not None: cross_attention_mask = cross_attention_mask[:, :, cache_position] full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] - elif token_idx is not None and past_key_values is not None: - cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1) - full_text_row_masked_out_mask = torch.index_select(full_text_row_masked_out_mask, -2, token_idx - 1) - + elif past_key_values is not None: + if token_idx is not None: + cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1) + full_text_row_masked_out_mask = torch.index_select( + full_text_row_masked_out_mask, -2, token_idx - 1 + ) + else: + cross_attention_mask = cross_attention_mask[:, :, -1:] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, -1:] outputs = self.language_model( input_ids=input_ids, attention_mask=attention_mask, @@ -895,7 +900,7 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_ - add token_idx handling """ cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) - model_kwargs = super()._update_model_kwargs_for_generation( + model_kwargs = super(MllamaForConditionalGeneration, self)._update_model_kwargs_for_generation( outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder, @@ -909,6 +914,10 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_ mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx - 1, ...] cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask) model_kwargs["cross_attention_mask"] = cross_attention_mask_prev + else: + model_kwargs["cross_attention_mask"] = torch.cat( + [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 + ) return model_kwargs From 7c2cd6319f65771f500708582b6042efb8c1e1e3 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sat, 19 Oct 2024 20:29:33 -0700 Subject: [PATCH 07/10] add flash attention support Signed-off-by: Wang, Yi A --- .../habana/transformers/generation/utils.py | 30 ++-- .../models/mllama/modeling_mllama.py | 130 +++++++++++++++--- 2 files changed, 132 insertions(+), 28 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 9e01608024..582e5dd0df 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -458,6 +458,14 @@ def update_model_kwargs_for_bucketing( ) else: assert False, "Not tested for cases where attn_mask isnt passed" + + if model_kwargs.get("cross_attention_mask") is not None: + model_kwargs["cross_attention_mask"] = torch.nn.functional.pad( + model_kwargs["cross_attention_mask"], + (0, 0, 0, 0, 0, pad_amount), + value=0, + ) + if reduce_recompile and params["passnum"] == 0: position_ids_cpu = model_kwargs["attention_mask"].long().cumsum(-1) - 1 position_ids_cpu.masked_fill_(model_kwargs["attention_mask"] == 0, 1) @@ -500,14 +508,20 @@ def create_pad_arg(pad_amount, i, j): # This is a necessary (but not sufficient) condition: what ever dimension we are padding, should be a multiple of bucket_size # This check is added in case we get a new model with a new kv-cache structure, and we attempt to pad some wrong dimension # in peft case, if there's virtual token. the model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size == num_virtual_token, no need of assert, the pad length of past_key_value should be aligned with input id and attention_mask - num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0) - assert ( - model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size - == num_virtual_tokens - ) - tmp_lst[j] = torch.nn.functional.pad( - model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id - ) + if ( + model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] + == params["allocated_space"] - pad_amount + ): + num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0) + assert ( + model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size + == num_virtual_tokens + ) + tmp_lst[j] = torch.nn.functional.pad( + model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id + ) + else: + tmp_lst[j] = model_kwargs["past_key_values"][i][j] new_kv[i] = tuple(tmp_lst) model_kwargs["past_key_values"] = tuple(new_kv) diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py index 7f09fccac8..e5c7ced0d4 100644 --- a/optimum/habana/transformers/models/mllama/modeling_mllama.py +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -15,6 +15,7 @@ """PyTorch Mllama model.""" import math +import os from typing import List, Optional, Tuple, Union import torch @@ -51,6 +52,21 @@ logger = logging.get_logger(__name__) +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + + +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale) + def _prepare_cross_attention_mask( cross_attention_mask: torch.Tensor, @@ -92,6 +108,10 @@ def _prepare_cross_attention_mask( class GaudiMllamaTextCrossAttention(MllamaTextCrossAttention): + def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + def forward( self, hidden_states: torch.Tensor, @@ -102,6 +122,8 @@ def forward( use_cache: bool = None, cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from MllamaTextCrossAttention::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L512 @@ -109,6 +131,7 @@ def forward( - add token_idx support - add support if past_key_value is not Cache - cache position is None + - add use_flash_attention and flash_attention_recompute """ """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() @@ -121,9 +144,9 @@ def forward( value_states = self.v_proj(cross_attention_states) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - + if not (FusedSDPA and use_flash_attention): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) key_states = self.k_norm(key_states) if past_key_value is not None: # if we have a new image + new tokens, we only computed key_states on that new image @@ -155,15 +178,31 @@ def forward( "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" ) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if FusedSDPA and use_flash_attention: + import habana_frameworks.torch.hpu as ht - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + if q_len == 1: + # next token + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) @@ -175,6 +214,10 @@ def forward( class GaudiMllamaTextSelfAttention(MllamaTextSelfAttention): + def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + def forward( self, hidden_states: torch.Tensor, @@ -185,6 +228,8 @@ def forward( past_key_value=None, cache_position=None, token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, **kwargs, ): """ @@ -192,6 +237,7 @@ def forward( The only differences are: - add token_idx support - add support if past_key_value is not Cache + - add use_flash_attention and flash_attention_recompute """ bsz, q_len, _ = hidden_states.size() @@ -225,19 +271,35 @@ def forward( if use_cache and not isinstance(past_key_value, Cache): past_key_value = [key_states, value_states] - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + if FusedSDPA and use_flash_attention: + import habana_frameworks.torch.hpu as ht + + if q_len == 1: + # next token + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, -1) @@ -270,11 +332,14 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Copied from MllamaSelfAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L904 The only differences are: - add token_idx input + - add use_flash_attention and flash_attention_recompute """ residual = hidden_states @@ -291,6 +356,8 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, ) hidden_states = residual + hidden_states @@ -330,12 +397,15 @@ def forward( cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, ) -> Tuple[torch.Tensor]: """ Copied from MllamaCrossAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L989 The only differences are: - add token_idx support - pass use_cache to cross_attn + - add use_flash_attention and flash_attention_recompute """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -349,6 +419,8 @@ def forward( cache_position=cache_position, use_cache=use_cache, token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, ) hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states @@ -387,12 +459,15 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from MllamaTextModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1617 The only differences are: - add token_idx support - add support if past_key_value is not Cache + - add use_flash_attention and flash_attention_recompute """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -506,6 +581,8 @@ def forward( cache_position=cache_position, position_embeddings=position_embeddings, token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, ) hidden_states = layer_outputs[0] @@ -628,12 +705,15 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Copied from MllamaForCausalLM::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1871 The only differences are: - add token_idx input - add logits handle if token_idx is not None + - add use_flash_attention and flash_attention_recompute """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -657,6 +737,8 @@ def forward( return_dict=return_dict, cache_position=cache_position, token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, ) hidden_states = outputs[0] @@ -720,12 +802,15 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077 The only differences are: - add token_idx input + - add use_flash_attention and flash_attention_recompute """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -803,6 +888,8 @@ def forward( cache_position=cache_position, num_logits_to_keep=num_logits_to_keep, token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, ) return outputs @@ -828,6 +915,7 @@ def prepare_inputs_for_generation( The only differences are: - add token_idx handling - add bucket_internal handling + - add use_flash_attention and flash_attention_recompute """ token_idx = kwargs.get("token_idx", None) bucket_internal = kwargs.get("bucket_internal", None) @@ -881,6 +969,8 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, "cross_attention_mask": cross_attention_mask, "token_idx": token_idx, + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), } ) From 57b0bf7056d4ce1a1a23281ef26fddf5b799f254 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 7 Nov 2024 21:11:31 -0800 Subject: [PATCH 08/10] add ci case and doc for mllama Signed-off-by: Wang, Yi A --- README.md | 1 + docs/source/index.mdx | 1 + examples/image-to-text/README.md | 79 +++++++++++++++++++ .../run_image2text_lora_finetune.py | 17 ++-- examples/image-to-text/run_pipeline.py | 13 +-- .../Llama_3_2_11B_Vision_Instruct.json | 38 +++++++++ tests/test_examples.py | 12 +-- tests/test_image_to_text_example.py | 1 + tests/utils.py | 1 + 9 files changed, 139 insertions(+), 24 deletions(-) create mode 100644 tests/baselines/Llama_3_2_11B_Vision_Instruct.json diff --git a/README.md b/README.md index 7c3cc5ccff..442eb9baff 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,7 @@ The following model architectures, tasks and device distributions have been vali | VideoMAE | |
  • Single card
  • |
  • [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)
  • | | TableTransformer | |
  • Single card
  • |
  • [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)
  • | | DETR | |
  • Single card
  • |
  • [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)
  • | +| Mllama |
  • LoRA
  • |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 775f0290ef..06ef4bb13e 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -80,6 +80,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | VideoMAE | |
  • Single card
  • |
  • [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)
  • | | TableTransformer | |
  • Single card
  • |
  • [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)
  • | | DETR | |
  • Single card
  • |
  • [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)
  • | +| Mllama |
  • LoRA
  • |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | - Diffusers diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index b36810f350..5916de4a29 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -31,6 +31,7 @@ Models that have been validated: - [llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) - [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf) - [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b) + - [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) ### Inference with BF16 @@ -102,6 +103,15 @@ python3 run_pipeline.py \ --bf16 ``` +To run mllama inference, use the following command: + +```bash +python3 run_pipeline.py \ + --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \ + --use_hpu_graphs \ + --bf16 +``` + ### Inference with FP8 Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision are enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch. @@ -286,6 +296,75 @@ python3 ../gaudi_spawn.py \ --lora_target_modules '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"' ``` +Here are single-/multi-device command examples for meta-llama/Llama-3.2-11B-Vision-Instruct. + +```bash +python3 run_image2text_lora_finetune.py \ + --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \ + --dataset_name nielsr/docvqa_1200_examples \ + --bf16 True \ + --output_dir ./model_lora_llama \ + --num_train_epochs 2 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --weight_decay 0.01 \ + --logging_steps 25 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 5e-5 \ + --warmup_steps 50 \ + --lr_scheduler_type "constant" \ + --input_column_names 'image' 'query' \ + --output_column_names 'answers' \ + --remove_unused_columns False \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --lora_rank=8 \ + --lora_alpha=8 \ + --lora_dropout=0.1 \ + --low_cpu_mem_usage True \ + --max_seq_length=512 \ + --use_hpu_graphs_for_inference True \ + --lora_target_modules ".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$" +``` + +```bash +python3 ../gaudi_spawn.py \ + --world_size 8 --use_mpi run_image2text_lora_finetune.py \ + --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \ + --dataset_name nielsr/docvqa_1200_examples \ + --bf16 True \ + --output_dir ./model_lora_llama \ + --num_train_epochs 2 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --weight_decay 0.01 \ + --logging_steps 25 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 5e-5 \ + --warmup_steps 50 \ + --lr_scheduler_type "constant" \ + --input_column_names 'image' 'query' \ + --output_column_names 'answers' \ + --remove_unused_columns False \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --lora_rank=8 \ + --lora_alpha=8 \ + --lora_dropout=0.1 \ + --low_cpu_mem_usage True \ + --max_seq_length=512 \ + --use_hpu_graphs_for_inference True \ + --lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"' +``` + ## Multi-HPU inference To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES=true`, diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py index 9ee5af3f77..ded60e6d52 100644 --- a/examples/image-to-text/run_image2text_lora_finetune.py +++ b/examples/image-to-text/run_image2text_lora_finetune.py @@ -251,11 +251,9 @@ class FinetuneArguments: class MyDataCollator: - def __init__(self, processor, max_seq_length): + def __init__(self, processor, max_seq_length, image_token_id): self.processor = processor - self.image_token_id = processor.tokenizer.additional_special_tokens_ids[ - processor.tokenizer.additional_special_tokens.index("") - ] + self.image_token_id = image_token_id self.max_seq_length = max_seq_length def __call__(self, examples): @@ -458,8 +456,15 @@ def main(): if col not in (data_args.input_column_names + data_args.output_column_names) ] ) - - data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length) + if hasattr(config, "image_token_id"): + # idefics + image_token_id = config.image_token_id + elif hasattr(config, "image_token_index"): + # mllama + image_token_id = config.image_token_index + else: + raise ValueError("Please provide value for image_token_id") + data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id) gaudi_config = GaudiConfig() gaudi_config.use_fused_adam = True diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 670c26e58a..3692278a90 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -262,7 +262,7 @@ def main(): htcore.hpu_initialize(generator.model) # delete once pipeline integrate AutoProcessor as preprocess engine - if model_type in ["idefics2"]: + if model_type in ["idefics2", "mllama"]: from transformers.image_utils import load_image def preprocess(self, image, prompt=None, timeout=None): @@ -272,17 +272,6 @@ def preprocess(self, image, prompt=None, timeout=None): generator.__class__.preprocess = preprocess - # delete once pipeline integrate AutoProcessor as preprocess engine - if model_type in ["mllama"]: - from transformers.image_utils import load_image - - def preprocess(self, image, prompt=None, timeout=None): - image = load_image(image, timeout=timeout) - model_inputs = processor(images=image, text=prompt, return_tensors=self.framework) - return model_inputs - - generator.__class__.preprocess = preprocess - # warm up for i in range(args.warmup): generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs) diff --git a/tests/baselines/Llama_3_2_11B_Vision_Instruct.json b/tests/baselines/Llama_3_2_11B_Vision_Instruct.json new file mode 100644 index 0000000000..3789c63fa9 --- /dev/null +++ b/tests/baselines/Llama_3_2_11B_Vision_Instruct.json @@ -0,0 +1,38 @@ +{ + "gaudi2": { + "image2text_lora_finetune": { + "num_train_epochs": 2, + "eval_batch_size": 4, + "distribution": { + "multi_card": { + "learning_rate": 5e-5, + "train_batch_size": 2, + "train_runtime": 470, + "train_samples_per_second": 22, + "eval_accuracy": 0.6, + "extra_arguments": [ + "--bf16", + "--gradient_accumulation_steps 8", + "--eval_strategy no", + "--save_strategy no", + "--warmup_steps 50", + "--lr_scheduler_type constant", + "--max_grad_norm 0.3", + "--logging_steps 1", + "--use_hpu_graphs_for_inference", + "--lora_rank 8", + "--lora_alpha 8", + "--lora_dropout 0.1", + "--lora_target_modules '.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'", + "--low_cpu_mem_usage True", + "--adam_epsilon 1e-08", + "--input_column_name image query", + "--output_column_name answers", + "--remove_unused_columns False", + "--max_seq_length 512" + ] + } + } + } + } +} diff --git a/tests/test_examples.py b/tests/test_examples.py index f24d250880..f84cdc75c6 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -34,6 +34,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_MAPPING, ) from transformers.testing_utils import slow @@ -203,8 +204,8 @@ def is_valid_model_type(model_type: str) -> bool: ), "run_image2text_lora_finetune": _get_supported_models_for_script( MODELS_TO_TEST_MAPPING, - MODEL_MAPPING, - ["idefics2"], + MODEL_FOR_VISION_2_SEQ_MAPPING, + ["idefics2", "mllama"], ), } @@ -421,10 +422,9 @@ def test(self): create_clip_roberta_model() self._install_requirements(example_script.parent / "requirements.txt") - - path_to_baseline = BASELINE_DIRECTORY / Path(model_name.split("/")[-1].replace("-", "_")).with_suffix( - ".json" - ) + path_to_baseline = BASELINE_DIRECTORY / Path( + model_name.split("/")[-1].replace("-", "_").replace(".", "_") + ).with_suffix(".json") with path_to_baseline.open("r") as json_file: device = "gaudi2" if IS_GAUDI2 else "gaudi" baseline = json.load(json_file)[device] diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py index 1cb8b95b33..60049bf46e 100644 --- a/tests/test_image_to_text_example.py +++ b/tests/test_image_to_text_example.py @@ -20,6 +20,7 @@ ("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 35.00608681379742), ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925), ("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077), + ("meta-llama/Llama-3.2-11B-Vision-Instruct", 1, 20.407843538649303), ], "fp8": [ ("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062), diff --git a/tests/utils.py b/tests/utils.py index 18c00a564c..7eab1b06be 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -62,6 +62,7 @@ "protst": [("mila-intel/protst-esm1b-for-sequential-classification", "Habana/gpt2")], "qwen2": [("Qwen/Qwen2-7B", "Habana/qwen"), ("Qwen/Qwen2-72B", "Habana/qwen")], "idefics2": [("HuggingFaceM4/idefics2-8b", "Habana/gpt2")], + "mllama": [("meta-llama/Llama-3.2-11B-Vision-Instruct", "Habana/gpt2")], } MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [ From 30d275bcbb3a3558d9d10c8847b3cf84e0e7ed6a Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 10 Nov 2024 21:12:28 -0800 Subject: [PATCH 09/10] enable TP for mllama skip TP for Vision part.need https://github.com/microsoft/DeepSpeed/commit/e97b453645a03bc6901d74ec13be1c5d7f1a1fec to enable vision part TP Signed-off-by: Wang, Yi A --- examples/image-to-text/run_pipeline.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 3692278a90..ebad5998f7 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -76,7 +76,18 @@ def initialize_distributed_model(args, model, logger, model_dtype): ds_inference_kwargs = {"dtype": model_dtype} ds_inference_kwargs["tensor_parallel"] = {"tp_size": args.world_size} ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs - ds_inference_kwargs["injection_policy"] = {} + if model.config.model_type == "mllama": + from transformers.models.mllama.modeling_mllama import ( + MllamaCrossAttentionDecoderLayer, + MllamaSelfAttentionDecoderLayer, + ) + + ds_inference_kwargs["injection_policy"] = { + MllamaSelfAttentionDecoderLayer: ("self_attn.o_proj", "mlp.down_proj"), + MllamaCrossAttentionDecoderLayer: ("cross_attn.o_proj", "mlp.down_proj"), + } + else: + ds_inference_kwargs["injection_policy"] = {} model = deepspeed.init_inference(model, **ds_inference_kwargs).module From d09beea21b492e7a6d711332f62775db18c5db1a Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 11 Nov 2024 16:43:59 -0800 Subject: [PATCH 10/10] enable mllama 90B TP. Signed-off-by: Wang, Yi A --- README.md | 2 +- docs/source/index.mdx | 2 +- examples/image-to-text/run_pipeline.py | 47 ++++++++++++++------------ 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 442eb9baff..308bcda2c7 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,7 @@ The following model architectures, tasks and device distributions have been vali | VideoMAE | |
  • Single card
  • |
  • [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)
  • | | TableTransformer | |
  • Single card
  • |
  • [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)
  • | | DETR | |
  • Single card
  • |
  • [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)
  • | -| Mllama |
  • LoRA
  • |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | +| Mllama |
  • LoRA
  • | :heavy_check_mark: |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 06ef4bb13e..8fe17ad95a 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -80,7 +80,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | VideoMAE | |
  • Single card
  • |
  • [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)
  • | | TableTransformer | |
  • Single card
  • |
  • [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)
  • | | DETR | |
  • Single card
  • |
  • [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)
  • | -| Mllama |
  • LoRA
  • |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | +| Mllama |
  • LoRA
  • |✅ |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | - Diffusers diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index ebad5998f7..75b391ea2e 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -23,7 +23,7 @@ import PIL.Image import requests import torch -from transformers import AutoConfig, AutoProcessor, pipeline +from transformers import AutoConfig, AutoModelForVision2Seq, AutoProcessor, pipeline from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi @@ -76,18 +76,7 @@ def initialize_distributed_model(args, model, logger, model_dtype): ds_inference_kwargs = {"dtype": model_dtype} ds_inference_kwargs["tensor_parallel"] = {"tp_size": args.world_size} ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs - if model.config.model_type == "mllama": - from transformers.models.mllama.modeling_mllama import ( - MllamaCrossAttentionDecoderLayer, - MllamaSelfAttentionDecoderLayer, - ) - - ds_inference_kwargs["injection_policy"] = { - MllamaSelfAttentionDecoderLayer: ("self_attn.o_proj", "mlp.down_proj"), - MllamaCrossAttentionDecoderLayer: ("cross_attn.o_proj", "mlp.down_proj"), - } - else: - ds_inference_kwargs["injection_policy"] = {} + ds_inference_kwargs["injection_policy"] = {} model = deepspeed.init_inference(model, **ds_inference_kwargs).module @@ -241,17 +230,31 @@ def main(): htcore.hpu_set_env() - generator = pipeline( - "image-to-text", - model=args.model_name_or_path, - torch_dtype=model_dtype, - device="hpu", - ) - if args.world_size > 1: - generator.model = initialize_distributed_model(args, generator.model, logger, model_dtype) - + import deepspeed + + with deepspeed.OnDevice(dtype=model_dtype, device="cpu"): + model = AutoModelForVision2Seq.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype) + if model_type == "mllama": + model.language_model = initialize_distributed_model(args, model.language_model, logger, model_dtype) + else: + model = initialize_distributed_model(args, model, logger, model_dtype) + generator = pipeline( + "image-to-text", + model=model, + config=args.model_name_or_path, + tokenizer=args.model_name_or_path, + image_processor=args.model_name_or_path, + torch_dtype=model_dtype, + device="hpu", + ) else: + generator = pipeline( + "image-to-text", + model=args.model_name_or_path, + torch_dtype=model_dtype, + device="hpu", + ) if args.use_hpu_graphs: from habana_frameworks.torch.hpu import wrap_in_hpu_graph