diff --git a/README.md b/README.md index 7c3cc5ccff..308bcda2c7 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,7 @@ The following model architectures, tasks and device distributions have been vali | VideoMAE | |
  • Single card
  • |
  • [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)
  • | | TableTransformer | |
  • Single card
  • |
  • [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)
  • | | DETR | |
  • Single card
  • |
  • [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)
  • | +| Mllama |
  • LoRA
  • | :heavy_check_mark: |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 775f0290ef..8fe17ad95a 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -80,6 +80,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | VideoMAE | |
  • Single card
  • |
  • [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification)
  • | | TableTransformer | |
  • Single card
  • |
  • [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection)
  • | | DETR | |
  • Single card
  • |
  • [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection)
  • | +| Mllama |
  • LoRA
  • |✅ |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | - Diffusers diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index b36810f350..5916de4a29 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -31,6 +31,7 @@ Models that have been validated: - [llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) - [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf) - [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b) + - [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) ### Inference with BF16 @@ -102,6 +103,15 @@ python3 run_pipeline.py \ --bf16 ``` +To run mllama inference, use the following command: + +```bash +python3 run_pipeline.py \ + --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \ + --use_hpu_graphs \ + --bf16 +``` + ### Inference with FP8 Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision are enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch. @@ -286,6 +296,75 @@ python3 ../gaudi_spawn.py \ --lora_target_modules '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"' ``` +Here are single-/multi-device command examples for meta-llama/Llama-3.2-11B-Vision-Instruct. 
+ +```bash +python3 run_image2text_lora_finetune.py \ + --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \ + --dataset_name nielsr/docvqa_1200_examples \ + --bf16 True \ + --output_dir ./model_lora_llama \ + --num_train_epochs 2 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --weight_decay 0.01 \ + --logging_steps 25 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 5e-5 \ + --warmup_steps 50 \ + --lr_scheduler_type "constant" \ + --input_column_names 'image' 'query' \ + --output_column_names 'answers' \ + --remove_unused_columns False \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --lora_rank=8 \ + --lora_alpha=8 \ + --lora_dropout=0.1 \ + --low_cpu_mem_usage True \ + --max_seq_length=512 \ + --use_hpu_graphs_for_inference True \ + --lora_target_modules ".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$" +``` + +```bash +python3 ../gaudi_spawn.py \ + --world_size 8 --use_mpi run_image2text_lora_finetune.py \ + --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \ + --dataset_name nielsr/docvqa_1200_examples \ + --bf16 True \ + --output_dir ./model_lora_llama \ + --num_train_epochs 2 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --weight_decay 0.01 \ + --logging_steps 25 \ + --eval_strategy "no" \ + --save_strategy "no" \ + --learning_rate 5e-5 \ + --warmup_steps 50 \ + --lr_scheduler_type "constant" \ + --input_column_names 'image' 'query' \ + --output_column_names 'answers' \ + --remove_unused_columns False \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --lora_rank=8 \ + --lora_alpha=8 \ + --lora_dropout=0.1 \ + --low_cpu_mem_usage True \ + --max_seq_length=512 \ + --use_hpu_graphs_for_inference True \ + --lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"' +``` + ## 
Multi-HPU inference To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES=true`, diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py index 9ee5af3f77..ded60e6d52 100644 --- a/examples/image-to-text/run_image2text_lora_finetune.py +++ b/examples/image-to-text/run_image2text_lora_finetune.py @@ -251,11 +251,9 @@ class FinetuneArguments: class MyDataCollator: - def __init__(self, processor, max_seq_length): + def __init__(self, processor, max_seq_length, image_token_id): self.processor = processor - self.image_token_id = processor.tokenizer.additional_special_tokens_ids[ - processor.tokenizer.additional_special_tokens.index("") - ] + self.image_token_id = image_token_id self.max_seq_length = max_seq_length def __call__(self, examples): @@ -458,8 +456,15 @@ def main(): if col not in (data_args.input_column_names + data_args.output_column_names) ] ) - - data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length) + if hasattr(config, "image_token_id"): + # idefics + image_token_id = config.image_token_id + elif hasattr(config, "image_token_index"): + # mllama + image_token_id = config.image_token_index + else: + raise ValueError("Please provide value for image_token_id") + data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id) gaudi_config = GaudiConfig() gaudi_config.use_fused_adam = True diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 65622a7323..75b391ea2e 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -23,7 +23,7 @@ import PIL.Image import requests import torch -from transformers import AutoConfig, AutoProcessor, pipeline +from transformers import AutoConfig, AutoModelForVision2Seq, AutoProcessor, pipeline from optimum.habana.transformers.modeling_utils import 
adapt_transformers_to_gaudi @@ -185,14 +185,13 @@ def main(): adapt_transformers_to_gaudi() model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type - if args.image_path is None and model_type in ["llava", "idefics2"]: + if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]: args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"] elif args.image_path is None and model_type == "llava_next": args.image_path = [ "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" ] - - if args.prompt is None and model_type in ["llava", "idefics2", "llava_next"]: + if args.prompt is None and model_type in ["llava", "idefics2", "llava_next", "mllama"]: processor = AutoProcessor.from_pretrained(args.model_name_or_path) conversation = [ { @@ -231,17 +230,31 @@ def main(): htcore.hpu_set_env() - generator = pipeline( - "image-to-text", - model=args.model_name_or_path, - torch_dtype=model_dtype, - device="hpu", - ) - if args.world_size > 1: - generator.model = initialize_distributed_model(args, generator.model, logger, model_dtype) - + import deepspeed + + with deepspeed.OnDevice(dtype=model_dtype, device="cpu"): + model = AutoModelForVision2Seq.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype) + if model_type == "mllama": + model.language_model = initialize_distributed_model(args, model.language_model, logger, model_dtype) + else: + model = initialize_distributed_model(args, model, logger, model_dtype) + generator = pipeline( + "image-to-text", + model=model, + config=args.model_name_or_path, + tokenizer=args.model_name_or_path, + image_processor=args.model_name_or_path, + torch_dtype=model_dtype, + device="hpu", + ) else: + generator = pipeline( + "image-to-text", + model=args.model_name_or_path, + torch_dtype=model_dtype, + device="hpu", + ) if args.use_hpu_graphs: from habana_frameworks.torch.hpu import wrap_in_hpu_graph @@ -263,7 +276,7 @@ def 
main(): htcore.hpu_initialize(generator.model) # delete once pipeline integrate AutoProcessor as preprocess engine - if model_type in ["idefics2"]: + if model_type in ["idefics2", "mllama"]: from transformers.image_utils import load_image def preprocess(self, image, prompt=None, timeout=None): diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d4ac4f1218..784d13719a 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -108,6 +108,7 @@ "qwen2_moe", "whisper", "idefics2", + "mllama", ] @@ -329,11 +330,13 @@ def _expand_dict_for_generation(dict_to_expand): def _pad_past_key_values(self, model_kwargs): pad_amount = model_kwargs.get("kv_cache_pad_len", 0) + kv_cache_len = model_kwargs.get("kv_cache_len", 0) if model_kwargs["past_key_values"]: if model_kwargs.get("mqa_model", False): for i in range(len(model_kwargs["past_key_values"])): # layer - if torch.is_tensor( - model_kwargs["past_key_values"][i] + if ( + torch.is_tensor(model_kwargs["past_key_values"][i]) + and model_kwargs["past_key_values"][i].shape[-2] == kv_cache_len - pad_amount ): # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked model_kwargs["past_key_values"][i] = torch.nn.functional.pad( model_kwargs["past_key_values"][i], (0, 0, 0, pad_amount) @@ -343,8 +346,9 @@ def _pad_past_key_values(self, model_kwargs): else: for i in range(len(model_kwargs["past_key_values"])): # layer for j in range(len(model_kwargs["past_key_values"][i])): # k or v - if torch.is_tensor( - model_kwargs["past_key_values"][i][j] + if ( + torch.is_tensor(model_kwargs["past_key_values"][i][j]) + and model_kwargs["past_key_values"][i][j].shape[-2] == kv_cache_len - pad_amount ): # tensor(batch_size, n_heads, kv_cache_len, head_dim) model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad( model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount) @@ -460,6 +464,14 @@ def 
update_model_kwargs_for_bucketing( ) else: assert False, "Not tested for cases where attn_mask isnt passed" + + if model_kwargs.get("cross_attention_mask") is not None: + model_kwargs["cross_attention_mask"] = torch.nn.functional.pad( + model_kwargs["cross_attention_mask"], + (0, 0, 0, 0, 0, pad_amount), + value=0, + ) + if reduce_recompile and params["passnum"] == 0: position_ids_cpu = model_kwargs["attention_mask"].long().cumsum(-1) - 1 position_ids_cpu.masked_fill_(model_kwargs["attention_mask"] == 0, 1) @@ -502,14 +514,20 @@ def create_pad_arg(pad_amount, i, j): # This is a necessary (but not sufficient) condition: what ever dimension we are padding, should be a multiple of bucket_size # This check is added in case we get a new model with a new kv-cache structure, and we attempt to pad some wrong dimension # in peft case, if there's virtual token. the model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size == num_virtual_token, no need of assert, the pad length of past_key_value should be aligned with input id and attention_mask - num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0) - assert ( - model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size - == num_virtual_tokens - ) - tmp_lst[j] = torch.nn.functional.pad( - model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id - ) + if ( + model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] + == params["allocated_space"] - pad_amount + ): + num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0) + assert ( + model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size + == num_virtual_tokens + ) + tmp_lst[j] = torch.nn.functional.pad( + model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id + ) + else: + tmp_lst[j] = model_kwargs["past_key_values"][i][j] new_kv[i] = tuple(tmp_lst) model_kwargs["past_key_values"] = tuple(new_kv) @@ -1109,6 +1127,12 @@ def generate( (0, 
generation_config.max_new_tokens), value=0, ) + if model_kwargs.get("cross_attention_mask") is not None: + model_kwargs["cross_attention_mask"] = torch.nn.functional.pad( + model_kwargs["cross_attention_mask"], + (0, 0, 0, 0, 0, generation_config.max_new_tokens), + value=0, + ) else: assert generation_config.bucket_size <= 0, "Untested path for bucket>0" if model_kwargs.get("decoder_input_ids", None) is None: diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 9659cbbd28..d49b25aa42 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -88,6 +88,14 @@ GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, GaudiMixtralModel, + GaudiMllamaCrossAttentionDecoderLayer, + GaudiMllamaForCausalLM, + GaudiMllamaForConditionalGeneration, + GaudiMllamaSelfAttentionDecoderLayer, + GaudiMllamaTextCrossAttention, + GaudiMllamaTextModel, + GaudiMllamaTextSelfAttention, + GaudiMllamaVisionModel, GaudiMptAttention, GaudiMptBlock, GaudiMptForCausalLM, @@ -618,6 +626,16 @@ def adapt_transformers_to_gaudi(): transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration = GaudiWhisperForConditionalGeneration transformers.models.whisper.modeling_whisper.WHISPER_ATTENTION_CLASSES = GAUDI_WHISPER_ATTENTION_CLASSES + # Optimization for mllama on Gaudi + transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer = GaudiMllamaSelfAttentionDecoderLayer + transformers.models.mllama.modeling_mllama.MllamaCrossAttentionDecoderLayer = GaudiMllamaCrossAttentionDecoderLayer + transformers.models.mllama.modeling_mllama.MllamaForCausalLM = GaudiMllamaForCausalLM + transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention = GaudiMllamaTextSelfAttention + transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention = GaudiMllamaTextCrossAttention + transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = 
GaudiMllamaForConditionalGeneration + transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel + transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel + transformers.AutoConfig.register("deci", DeciLMConfig) transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM) diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 232cb7522a..51c76140d1 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -150,6 +150,16 @@ gaudi_mixtral_block_sparse_moe_forward, gaudi_mixtral_rmsnorm_forward, ) +from .mllama import ( + GaudiMllamaCrossAttentionDecoderLayer, + GaudiMllamaForCausalLM, + GaudiMllamaForConditionalGeneration, + GaudiMllamaSelfAttentionDecoderLayer, + GaudiMllamaTextCrossAttention, + GaudiMllamaTextModel, + GaudiMllamaTextSelfAttention, + GaudiMllamaVisionModel, +) from .modeling_all_models import ( gaudi_check_and_enable_sdpa, gaudi_conv1d_forward, diff --git a/optimum/habana/transformers/models/mllama/__init__.py b/optimum/habana/transformers/models/mllama/__init__.py new file mode 100644 index 0000000000..198f1cc2aa --- /dev/null +++ b/optimum/habana/transformers/models/mllama/__init__.py @@ -0,0 +1,10 @@ +from .modeling_mllama import ( + GaudiMllamaCrossAttentionDecoderLayer, + GaudiMllamaForCausalLM, + GaudiMllamaForConditionalGeneration, + GaudiMllamaSelfAttentionDecoderLayer, + GaudiMllamaTextCrossAttention, + GaudiMllamaTextModel, + GaudiMllamaTextSelfAttention, + GaudiMllamaVisionModel, +) diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py new file mode 100644 index 0000000000..e5c7ced0d4 --- /dev/null +++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py @@ -0,0 +1,1157 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" + +import math +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig +from transformers.models.mllama.modeling_mllama import ( + MllamaCrossAttentionDecoderLayer, + MllamaForCausalLM, + MllamaForConditionalGeneration, + MllamaSelfAttentionDecoderLayer, + MllamaTextCrossAttention, + MllamaTextModel, + MllamaTextSelfAttention, + MllamaVisionModel, + _prepare_4d_causal_attention_mask_with_cache_position, + _prepare_aspect_ratio_attention_mask, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import ( + logging, +) + +from ...modeling_attn_mask_utils import ( + _gaudi_prepare_4d_causal_attention_mask, +) + + +logger = logging.get_logger(__name__) + +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + + +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + 
self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale) + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + num_vision_tokens: int, + dtype: str, + token_idx: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Copied from _prepare_cross_attention_mask: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L99 + The only differences are: + - if there's pading in cross_attention_mask in the right. do not masked it, or else it will impact softmax in crossattention + """ + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + if token_idx is not None: + full_text_row_masked_out_mask2 = full_text_row_masked_out_mask.clone() + full_text_row_masked_out_mask2[:, :, token_idx:, :] = 1 + cross_attention_mask *= full_text_row_masked_out_mask2 + else: + cross_attention_mask *= full_text_row_masked_out_mask + + return 
cross_attention_mask, full_text_row_masked_out_mask + + +class GaudiMllamaTextCrossAttention(MllamaTextCrossAttention): + def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from MllamaTextCrossAttention::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L512 + The only differences are: + - add token_idx support + - add support if past_key_value is not Cache + - cache position is None + - add use_flash_attention and flash_attention_recompute + """ + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if not (FusedSDPA and use_flash_attention): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = 
repeat_kv(value_states, self.num_key_value_groups) + key_states = self.k_norm(key_states) + if past_key_value is not None: + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! + if isinstance(past_key_value, Cache): + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + else: + if token_idx is not None: + past_key_value[0].index_copy_(2, token_idx - 1, key_states) + past_key_value[1].index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value[0] + value_states = past_key_value[1] + else: + key_states = torch.cat((past_key_value[0], key_states), dim=2) + value_states = torch.cat((past_key_value[1], value_states), dim=2) + if use_cache and not isinstance(past_key_value, Cache): + past_key_value = [key_states, value_states] + elif not isinstance(past_key_value, Cache) and past_key_value is not None: + key_states, value_states = (past_key_value[0], past_key_value[1]) + elif cache_position is not None and cache_position[0] != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" 
+ ) + + if FusedSDPA and use_flash_attention: + import habana_frameworks.torch.hpu as ht + + if q_len == 1: + # next token + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class GaudiMllamaTextSelfAttention(MllamaTextSelfAttention): + def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: 
Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + **kwargs, + ): + """ + Copied from MllamaTextSelfAttention::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L733 + The only differences are: + - add token_idx support + - add support if past_key_value is not Cache + - add use_flash_attention and flash_attention_recompute + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + if isinstance(past_key_value, Cache): + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + else: + if token_idx is not None: + past_key_value[0].index_copy_(2, token_idx - 1, key_states) + past_key_value[1].index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value[0] + value_states = past_key_value[1] + else: + key_states = torch.cat((past_key_value[0], key_states), dim=2) + value_states = torch.cat((past_key_value[1], value_states), dim=2) + if use_cache and not isinstance(past_key_value, Cache): + past_key_value = [key_states, value_states] + + if FusedSDPA and use_flash_attention: + import habana_frameworks.torch.hpu as ht + + if q_len == 1: + # next token + use_recompute = True if 
os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer +class GaudiMllamaSelfAttentionDecoderLayer(MllamaSelfAttentionDecoderLayer): + def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: + super(GaudiMllamaSelfAttentionDecoderLayer, self).__init__(config, layer_idx) + self.self_attn = GaudiMllamaTextSelfAttention(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + attention_mask: 
Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from MllamaSelfAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L904 + The only differences are: + - add token_idx input + - add use_flash_attention and flash_attention_recompute + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class 
GaudiMllamaCrossAttentionDecoderLayer(MllamaCrossAttentionDecoderLayer): + def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: + super(GaudiMllamaCrossAttentionDecoderLayer, self).__init__(config, layer_idx) + self.cross_attn = GaudiMllamaTextCrossAttention(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + attention_mask: torch.Tensor, + full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + """ + Copied from MllamaCrossAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L989 + The only differences are: + - add token_idx support + - pass use_cache to cross_attn + - add use_flash_attention and flash_attention_recompute + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, attn_weights, past_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + use_cache=use_cache, + token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = 
self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if full_text_row_masked_out_mask is not None: + hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + +class GaudiMllamaTextModel(MllamaTextModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.FloatTensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Copied from MllamaTextModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1617 + The only differences are: + - add token_idx support + - add support if past_key_value is not Cache + - add use_flash_attention and flash_attention_recompute + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if 
use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + if isinstance(past_key_values, Cache): + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + else: + past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ignore_cache_position = True # Ignore cache position on HPU; otherwise HPU graphs may have issues + if ignore_cache_position is False: + if cache_position is None: + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + else: + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, + inputs_embeds.shape[1] + past_seen_tokens, + dtype=torch.long, + device=inputs_embeds.device, + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape, + inputs_embeds, + past_seen_tokens, + ) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if 
output_attentions else None + next_decoder_cache = None if isinstance(past_key_values, Cache) else () + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # For text-only path we should skip cross attention layers. + # Let's check if the layer is cross attention layer and if we have cross attention states + # or cached cross attention states. + is_cross_attention_layer = idx in self.cross_attention_layers + is_cross_attention_cache_empty = past_key_values is None or ( + past_key_values is not None and past_key_values.get_seq_length(idx) == 0 + if isinstance(past_key_values, Cache) + else False + ) + + if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty: + continue + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + cross_attention_states, + cross_attention_mask, + causal_mask, + full_text_row_masked_out_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + if isinstance(past_key_values, Cache): + past_key_value = past_key_values + else: + past_key_value = None if past_key_values is None else past_key_values[idx] + layer_outputs = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + attention_mask=causal_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + if isinstance(past_key_values, Cache): + 
next_decoder_cache = layer_outputs[2 if output_attentions else 1] + else: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + """ + Copied from MllamaTextModel::_update_causal_mask: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1768 + The only differences are: + - add support if past_key_value is not Cache + """ + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ if isinstance(past_key_values, Cache): + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + else: + past_seen_tokens = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line + # self.config._attn_implementation == "sdpa" and + if self.config._attn_implementation == "sdpa" and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class GaudiMllamaForCausalLM(MllamaForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Copied from MllamaForCausalLM::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1871 + The only differences are: + - add token_idx input + - add logits handle if token_idx is not None + - add use_flash_attention and flash_attention_recompute + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + 
cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + + hidden_states = outputs[0] + + if token_idx is None and num_logits_to_keep != 0: + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + else: + logits = self.lm_head(hidden_states).float() + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class GaudiMllamaForConditionalGeneration(MllamaForConditionalGeneration): + def __init__(self, config: MllamaConfig): + # sdpa is better for vision model in HPU + config._attn_implementation = "sdpa" + super(GaudiMllamaForConditionalGeneration, self).__init__(config) + + def forward( + self, + input_ids: 
Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077 + The only differences are: + - add token_idx input + - add use_flash_attention and flash_attention_recompute + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + 
) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + token_idx=token_idx, + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None: + if cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + elif past_key_values is not None: + if token_idx is not None: + cross_attention_mask = torch.index_select(cross_attention_mask, -2, token_idx - 1) + full_text_row_masked_out_mask = torch.index_select( + full_text_row_masked_out_mask, -2, token_idx - 1 + ) + else: + cross_attention_mask = cross_attention_mask[:, :, -1:] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, -1:] + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + 
full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + token_idx=token_idx, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + ) + + return outputs + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + pixel_values=None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, + past_key_values=None, + use_cache=False, + cache_position=None, + num_logits_to_keep=None, + **kwargs, + ): + """ + Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208 + The only differences are: + - add token_idx handling + - add bucket_internal handling + - add use_flash_attention and flash_attention_recompute + """ + token_idx = kwargs.get("token_idx", None) + bucket_internal = kwargs.get("bucket_internal", None) + if past_key_values is not None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + elif bucket_internal and token_idx is not None: + # for the 1st token we can slice the inputs till token idx for the fwd pass. + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] + if cross_attention_mask is not None: + cross_attention_mask = cross_attention_mask[:, :token_idx, ...] 
+ + # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have a varying stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. 
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + # keep cache_position implementation as None for HPU + cache_position = None + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "cross_attention_mask": cross_attention_mask, + "token_idx": token_idx, + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + } + ) + + # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios + # to compute image hidden states, otherwise they are cached within each cross attn layer + if (input_ids == self.config.image_token_index).any(): + model_inputs["pixel_values"] = pixel_values + model_inputs["aspect_ratio_ids"] = aspect_ratio_ids + model_inputs["aspect_ratio_mask"] = aspect_ratio_mask + + return model_inputs + + def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): + """ + Copied from MllamaForConditionalGeneration::_update_model_kwargs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2274 + The only differences are: + - add token_idx handling + """ + cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) + model_kwargs = super(MllamaForConditionalGeneration, self)._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + # add cross-attn mask for new token + if cross_attention_mask_prev is not None: + token_idx = model_kwargs.get("token_idx", None) + if token_idx is not None: + mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx 
- 1, ...] + cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask) + model_kwargs["cross_attention_mask"] = cross_attention_mask_prev + else: + model_kwargs["cross_attention_mask"] = torch.cat( + [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 + ) + return model_kwargs + + +class GaudiMllamaVisionModel(MllamaVisionModel): + def forward( + self, + pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: + """ + Copied from MllamaVisionModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1425 + The only differences are: + - optimize perf of stage "Collect intermediate layer outputs from encoder output" + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape + + pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) + + # Patch embedding + patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # Tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids) + + # Add cls token + 
hidden_state = hidden_state.reshape(batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = (0, 0, 0, num_padding_patches) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + # Prepare attention mask + attention_mask = aspect_ratio_mask.reshape(batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + # Apply encoder + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + ) + hidden_state = output[0] + + hidden_state = self.layernorm_post(hidden_state) + + # Apply global encoder + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles * (num_patches + num_padding_patches), dim + ) + global_output = self.global_transformer( + hidden_state, + attention_mask=attention_mask, + 
output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + hidden_state = global_output[0] + + # Remove padding from hidden state + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = hidden_state[:, :, :slice_index] + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) + + # Collect intermediate layer outputs from encoder output + all_intermediate_hidden_states = output[1] + intermediate_hidden_states = [ + hidden_state + for idx, hidden_state in enumerate(all_intermediate_hidden_states) + if idx in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + + """ + intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1) + intermediate_hidden_states = intermediate_hidden_states[..., self.intermediate_layers_indices] + """ + + # Remove padding from intermediate hidden states + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1 + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + + # Concatenate final hidden state and intermediate hidden states + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + + if output_hidden_states: + hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1]) + else: + hidden_states = None + + if output_attentions: + # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range + global_attn = tuple(global_output[2]) if output_hidden_states else tuple(global_output[1]) + attentions = tuple(output[2]) + global_attn + else: + 
attentions = None + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + attentions=attentions, + ) diff --git a/tests/baselines/Llama_3_2_11B_Vision_Instruct.json b/tests/baselines/Llama_3_2_11B_Vision_Instruct.json new file mode 100644 index 0000000000..3789c63fa9 --- /dev/null +++ b/tests/baselines/Llama_3_2_11B_Vision_Instruct.json @@ -0,0 +1,38 @@ +{ + "gaudi2": { + "image2text_lora_finetune": { + "num_train_epochs": 2, + "eval_batch_size": 4, + "distribution": { + "multi_card": { + "learning_rate": 5e-5, + "train_batch_size": 2, + "train_runtime": 470, + "train_samples_per_second": 22, + "eval_accuracy": 0.6, + "extra_arguments": [ + "--bf16", + "--gradient_accumulation_steps 8", + "--eval_strategy no", + "--save_strategy no", + "--warmup_steps 50", + "--lr_scheduler_type constant", + "--max_grad_norm 0.3", + "--logging_steps 1", + "--use_hpu_graphs_for_inference", + "--lora_rank 8", + "--lora_alpha 8", + "--lora_dropout 0.1", + "--lora_target_modules '.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'", + "--low_cpu_mem_usage True", + "--adam_epsilon 1e-08", + "--input_column_name image query", + "--output_column_name answers", + "--remove_unused_columns False", + "--max_seq_length 512" + ] + } + } + } + } +} diff --git a/tests/test_examples.py b/tests/test_examples.py index f24d250880..f84cdc75c6 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -34,6 +34,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, + MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_MAPPING, ) from transformers.testing_utils import slow @@ -203,8 +204,8 @@ def is_valid_model_type(model_type: str) -> bool: ), "run_image2text_lora_finetune": _get_supported_models_for_script( MODELS_TO_TEST_MAPPING, - MODEL_MAPPING, - ["idefics2"], 
+ MODEL_FOR_VISION_2_SEQ_MAPPING, + ["idefics2", "mllama"], ), } @@ -421,10 +422,9 @@ def test(self): create_clip_roberta_model() self._install_requirements(example_script.parent / "requirements.txt") - - path_to_baseline = BASELINE_DIRECTORY / Path(model_name.split("/")[-1].replace("-", "_")).with_suffix( - ".json" - ) + path_to_baseline = BASELINE_DIRECTORY / Path( + model_name.split("/")[-1].replace("-", "_").replace(".", "_") + ).with_suffix(".json") with path_to_baseline.open("r") as json_file: device = "gaudi2" if IS_GAUDI2 else "gaudi" baseline = json.load(json_file)[device] diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py index 1cb8b95b33..60049bf46e 100644 --- a/tests/test_image_to_text_example.py +++ b/tests/test_image_to_text_example.py @@ -20,6 +20,7 @@ ("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 35.00608681379742), ("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925), ("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077), + ("meta-llama/Llama-3.2-11B-Vision-Instruct", 1, 20.407843538649303), ], "fp8": [ ("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062), diff --git a/tests/utils.py b/tests/utils.py index 18c00a564c..7eab1b06be 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -62,6 +62,7 @@ "protst": [("mila-intel/protst-esm1b-for-sequential-classification", "Habana/gpt2")], "qwen2": [("Qwen/Qwen2-7B", "Habana/qwen"), ("Qwen/Qwen2-72B", "Habana/qwen")], "idefics2": [("HuggingFaceM4/idefics2-8b", "Habana/gpt2")], + "mllama": [("meta-llama/Llama-3.2-11B-Vision-Instruct", "Habana/gpt2")], } MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [