diff --git a/README.md b/README.md
index 7c3cc5ccff..308bcda2c7 100644
--- a/README.md
+++ b/README.md
@@ -234,6 +234,7 @@ The following model architectures, tasks and device distributions have been vali
| VideoMAE | |
Single card | [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification) |
| TableTransformer | | Single card | [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection) |
| DETR | | Single card | [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection) |
+| Mllama | LoRA | :heavy_check_mark: | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 775f0290ef..8fe17ad95a 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -80,6 +80,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| VideoMAE | | Single card | [Video classification](https://github.com/huggingface/optimum-habana/tree/main/examples/video-classification) |
| TableTransformer | | Single card | [table object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/table-detection) |
| DETR | | Single card | [object detection](https://github.com/huggingface/optimum-habana/tree/main/examples/object-detection) |
+| Mllama | LoRA | ✅ | [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text) |
- Diffusers
diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md
index b36810f350..5916de4a29 100644
--- a/examples/image-to-text/README.md
+++ b/examples/image-to-text/README.md
@@ -31,6 +31,7 @@ Models that have been validated:
- [llava-hf/llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf)
- [llava-hf/llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llama3-llava-next-8b-hf)
- [HuggingFaceM4/idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b)
+ - [meta-llama/Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)
### Inference with BF16
@@ -102,6 +103,15 @@ python3 run_pipeline.py \
--bf16
```
+To run mllama inference, use the following command:
+
+```bash
+python3 run_pipeline.py \
+ --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
+ --use_hpu_graphs \
+ --bf16
+```
+
### Inference with FP8
Inference for Llava-1.5-7b, Llava-1.5-13b, Llava-v1.6-mistral-7b and Llava-v1.6-vicuna-13b in FP8 precision are enabled using [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html), which provides model measurement and quantization capabilities in PyTorch.
@@ -286,6 +296,75 @@ python3 ../gaudi_spawn.py \
--lora_target_modules '".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
```
+Here are single-/multi-device command examples for meta-llama/Llama-3.2-11B-Vision-Instruct.
+
+```bash
+python3 run_image2text_lora_finetune.py \
+ --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
+ --dataset_name nielsr/docvqa_1200_examples \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 2 \
+ --per_device_train_batch_size 2 \
+ --per_device_eval_batch_size 2 \
+ --gradient_accumulation_steps 8 \
+ --weight_decay 0.01 \
+ --logging_steps 25 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --lr_scheduler_type "constant" \
+ --input_column_names 'image' 'query' \
+ --output_column_names 'answers' \
+ --remove_unused_columns False \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --lora_rank=8 \
+ --lora_alpha=8 \
+ --lora_dropout=0.1 \
+ --low_cpu_mem_usage True \
+ --max_seq_length=512 \
+ --use_hpu_graphs_for_inference True \
+ --lora_target_modules ".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"
+```
+
+```bash
+python3 ../gaudi_spawn.py \
+ --world_size 8 --use_mpi run_image2text_lora_finetune.py \
+ --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct \
+ --dataset_name nielsr/docvqa_1200_examples \
+ --bf16 True \
+ --output_dir ./model_lora_llama \
+ --num_train_epochs 2 \
+ --per_device_train_batch_size 2 \
+ --per_device_eval_batch_size 2 \
+ --gradient_accumulation_steps 8 \
+ --weight_decay 0.01 \
+ --logging_steps 25 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 5e-5 \
+ --warmup_steps 50 \
+ --lr_scheduler_type "constant" \
+ --input_column_names 'image' 'query' \
+ --output_column_names 'answers' \
+ --remove_unused_columns False \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --lora_rank=8 \
+ --lora_alpha=8 \
+ --lora_dropout=0.1 \
+ --low_cpu_mem_usage True \
+ --max_seq_length=512 \
+ --use_hpu_graphs_for_inference True \
+ --lora_target_modules '".*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$"'
+```
+
## Multi-HPU inference
To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES=true`,
diff --git a/examples/image-to-text/run_image2text_lora_finetune.py b/examples/image-to-text/run_image2text_lora_finetune.py
index 9ee5af3f77..ded60e6d52 100644
--- a/examples/image-to-text/run_image2text_lora_finetune.py
+++ b/examples/image-to-text/run_image2text_lora_finetune.py
@@ -251,11 +251,9 @@ class FinetuneArguments:
class MyDataCollator:
- def __init__(self, processor, max_seq_length):
+ def __init__(self, processor, max_seq_length, image_token_id):
self.processor = processor
- self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
- processor.tokenizer.additional_special_tokens.index("")
- ]
+ self.image_token_id = image_token_id
self.max_seq_length = max_seq_length
def __call__(self, examples):
@@ -458,8 +456,15 @@ def main():
if col not in (data_args.input_column_names + data_args.output_column_names)
]
)
-
- data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length)
+ if hasattr(config, "image_token_id"):
+ # idefics
+ image_token_id = config.image_token_id
+ elif hasattr(config, "image_token_index"):
+ # mllama
+ image_token_id = config.image_token_index
+ else:
+ raise ValueError("Please provide value for image_token_id")
+ data_collator = MyDataCollator(processor, max_seq_length=data_args.max_seq_length, image_token_id=image_token_id)
gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py
index 65622a7323..75b391ea2e 100644
--- a/examples/image-to-text/run_pipeline.py
+++ b/examples/image-to-text/run_pipeline.py
@@ -23,7 +23,7 @@
import PIL.Image
import requests
import torch
-from transformers import AutoConfig, AutoProcessor, pipeline
+from transformers import AutoConfig, AutoModelForVision2Seq, AutoProcessor, pipeline
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -185,14 +185,13 @@ def main():
adapt_transformers_to_gaudi()
model_type = AutoConfig.from_pretrained(args.model_name_or_path).model_type
- if args.image_path is None and model_type in ["llava", "idefics2"]:
+ if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]:
args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"]
elif args.image_path is None and model_type == "llava_next":
args.image_path = [
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
]
-
- if args.prompt is None and model_type in ["llava", "idefics2", "llava_next"]:
+ if args.prompt is None and model_type in ["llava", "idefics2", "llava_next", "mllama"]:
processor = AutoProcessor.from_pretrained(args.model_name_or_path)
conversation = [
{
@@ -231,17 +230,31 @@ def main():
htcore.hpu_set_env()
- generator = pipeline(
- "image-to-text",
- model=args.model_name_or_path,
- torch_dtype=model_dtype,
- device="hpu",
- )
-
if args.world_size > 1:
- generator.model = initialize_distributed_model(args, generator.model, logger, model_dtype)
-
+ import deepspeed
+
+ with deepspeed.OnDevice(dtype=model_dtype, device="cpu"):
+ model = AutoModelForVision2Seq.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype)
+ if model_type == "mllama":
+ model.language_model = initialize_distributed_model(args, model.language_model, logger, model_dtype)
+ else:
+ model = initialize_distributed_model(args, model, logger, model_dtype)
+ generator = pipeline(
+ "image-to-text",
+ model=model,
+ config=args.model_name_or_path,
+ tokenizer=args.model_name_or_path,
+ image_processor=args.model_name_or_path,
+ torch_dtype=model_dtype,
+ device="hpu",
+ )
else:
+ generator = pipeline(
+ "image-to-text",
+ model=args.model_name_or_path,
+ torch_dtype=model_dtype,
+ device="hpu",
+ )
if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -263,7 +276,7 @@ def main():
htcore.hpu_initialize(generator.model)
# delete once pipeline integrate AutoProcessor as preprocess engine
- if model_type in ["idefics2"]:
+ if model_type in ["idefics2", "mllama"]:
from transformers.image_utils import load_image
def preprocess(self, image, prompt=None, timeout=None):
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index d4ac4f1218..784d13719a 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -108,6 +108,7 @@
"qwen2_moe",
"whisper",
"idefics2",
+ "mllama",
]
@@ -329,11 +330,13 @@ def _expand_dict_for_generation(dict_to_expand):
def _pad_past_key_values(self, model_kwargs):
pad_amount = model_kwargs.get("kv_cache_pad_len", 0)
+ kv_cache_len = model_kwargs.get("kv_cache_len", 0)
if model_kwargs["past_key_values"]:
if model_kwargs.get("mqa_model", False):
for i in range(len(model_kwargs["past_key_values"])): # layer
- if torch.is_tensor(
- model_kwargs["past_key_values"][i]
+ if (
+ torch.is_tensor(model_kwargs["past_key_values"][i])
+ and model_kwargs["past_key_values"][i].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, kv_cache_len, n_heads * head_dim * 2) k and v stacked
model_kwargs["past_key_values"][i] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i], (0, 0, 0, pad_amount)
@@ -343,8 +346,9 @@ def _pad_past_key_values(self, model_kwargs):
else:
for i in range(len(model_kwargs["past_key_values"])): # layer
for j in range(len(model_kwargs["past_key_values"][i])): # k or v
- if torch.is_tensor(
- model_kwargs["past_key_values"][i][j]
+ if (
+ torch.is_tensor(model_kwargs["past_key_values"][i][j])
+ and model_kwargs["past_key_values"][i][j].shape[-2] == kv_cache_len - pad_amount
): # tensor(batch_size, n_heads, kv_cache_len, head_dim)
model_kwargs["past_key_values"][i][j] = torch.nn.functional.pad(
model_kwargs["past_key_values"][i][j], (0, 0, 0, pad_amount)
@@ -460,6 +464,14 @@ def update_model_kwargs_for_bucketing(
)
else:
assert False, "Not tested for cases where attn_mask isnt passed"
+
+ if model_kwargs.get("cross_attention_mask") is not None:
+ model_kwargs["cross_attention_mask"] = torch.nn.functional.pad(
+ model_kwargs["cross_attention_mask"],
+ (0, 0, 0, 0, 0, pad_amount),
+ value=0,
+ )
+
if reduce_recompile and params["passnum"] == 0:
position_ids_cpu = model_kwargs["attention_mask"].long().cumsum(-1) - 1
position_ids_cpu.masked_fill_(model_kwargs["attention_mask"] == 0, 1)
@@ -502,14 +514,20 @@ def create_pad_arg(pad_amount, i, j):
# This is a necessary (but not sufficient) condition: what ever dimension we are padding, should be a multiple of bucket_size
# This check is added in case we get a new model with a new kv-cache structure, and we attempt to pad some wrong dimension
# in peft case, if there's virtual token. the model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size == num_virtual_token, no need of assert, the pad length of past_key_value should be aligned with input id and attention_mask
- num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
- assert (
- model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size
- == num_virtual_tokens
- )
- tmp_lst[j] = torch.nn.functional.pad(
- model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id
- )
+ if (
+ model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)]
+ == params["allocated_space"] - pad_amount
+ ):
+ num_virtual_tokens = model_kwargs.get("num_virtual_tokens", 0)
+ assert (
+ model_kwargs["past_key_values"][i][j].shape[-(len(pad_tuple) // 2)] % bucket_size
+ == num_virtual_tokens
+ )
+ tmp_lst[j] = torch.nn.functional.pad(
+ model_kwargs["past_key_values"][i][j], pad_tuple, value=pad_token_id
+ )
+ else:
+ tmp_lst[j] = model_kwargs["past_key_values"][i][j]
new_kv[i] = tuple(tmp_lst)
model_kwargs["past_key_values"] = tuple(new_kv)
@@ -1109,6 +1127,12 @@ def generate(
(0, generation_config.max_new_tokens),
value=0,
)
+ if model_kwargs.get("cross_attention_mask") is not None:
+ model_kwargs["cross_attention_mask"] = torch.nn.functional.pad(
+ model_kwargs["cross_attention_mask"],
+ (0, 0, 0, 0, 0, generation_config.max_new_tokens),
+ value=0,
+ )
else:
assert generation_config.bucket_size <= 0, "Untested path for bucket>0"
if model_kwargs.get("decoder_input_ids", None) is None:
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 9659cbbd28..d49b25aa42 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -88,6 +88,14 @@
GaudiMixtralDecoderLayer,
GaudiMixtralForCausalLM,
GaudiMixtralModel,
+ GaudiMllamaCrossAttentionDecoderLayer,
+ GaudiMllamaForCausalLM,
+ GaudiMllamaForConditionalGeneration,
+ GaudiMllamaSelfAttentionDecoderLayer,
+ GaudiMllamaTextCrossAttention,
+ GaudiMllamaTextModel,
+ GaudiMllamaTextSelfAttention,
+ GaudiMllamaVisionModel,
GaudiMptAttention,
GaudiMptBlock,
GaudiMptForCausalLM,
@@ -618,6 +626,16 @@ def adapt_transformers_to_gaudi():
transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration = GaudiWhisperForConditionalGeneration
transformers.models.whisper.modeling_whisper.WHISPER_ATTENTION_CLASSES = GAUDI_WHISPER_ATTENTION_CLASSES
+ # Optimization for mllama on Gaudi
+ transformers.models.mllama.modeling_mllama.MllamaSelfAttentionDecoderLayer = GaudiMllamaSelfAttentionDecoderLayer
+ transformers.models.mllama.modeling_mllama.MllamaCrossAttentionDecoderLayer = GaudiMllamaCrossAttentionDecoderLayer
+ transformers.models.mllama.modeling_mllama.MllamaForCausalLM = GaudiMllamaForCausalLM
+ transformers.models.mllama.modeling_mllama.MllamaTextSelfAttention = GaudiMllamaTextSelfAttention
+ transformers.models.mllama.modeling_mllama.MllamaTextCrossAttention = GaudiMllamaTextCrossAttention
+ transformers.models.mllama.modeling_mllama.MllamaForConditionalGeneration = GaudiMllamaForConditionalGeneration
+ transformers.models.mllama.modeling_mllama.MllamaTextModel = GaudiMllamaTextModel
+ transformers.models.mllama.modeling_mllama.MllamaVisionModel = GaudiMllamaVisionModel
+
transformers.AutoConfig.register("deci", DeciLMConfig)
transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM)
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 232cb7522a..51c76140d1 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -150,6 +150,16 @@
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_rmsnorm_forward,
)
+from .mllama import (
+ GaudiMllamaCrossAttentionDecoderLayer,
+ GaudiMllamaForCausalLM,
+ GaudiMllamaForConditionalGeneration,
+ GaudiMllamaSelfAttentionDecoderLayer,
+ GaudiMllamaTextCrossAttention,
+ GaudiMllamaTextModel,
+ GaudiMllamaTextSelfAttention,
+ GaudiMllamaVisionModel,
+)
from .modeling_all_models import (
gaudi_check_and_enable_sdpa,
gaudi_conv1d_forward,
diff --git a/optimum/habana/transformers/models/mllama/__init__.py b/optimum/habana/transformers/models/mllama/__init__.py
new file mode 100644
index 0000000000..198f1cc2aa
--- /dev/null
+++ b/optimum/habana/transformers/models/mllama/__init__.py
@@ -0,0 +1,10 @@
+from .modeling_mllama import (
+ GaudiMllamaCrossAttentionDecoderLayer,
+ GaudiMllamaForCausalLM,
+ GaudiMllamaForConditionalGeneration,
+ GaudiMllamaSelfAttentionDecoderLayer,
+ GaudiMllamaTextCrossAttention,
+ GaudiMllamaTextModel,
+ GaudiMllamaTextSelfAttention,
+ GaudiMllamaVisionModel,
+)
diff --git a/optimum/habana/transformers/models/mllama/modeling_mllama.py b/optimum/habana/transformers/models/mllama/modeling_mllama.py
new file mode 100644
index 0000000000..e5c7ced0d4
--- /dev/null
+++ b/optimum/habana/transformers/models/mllama/modeling_mllama.py
@@ -0,0 +1,1157 @@
+# coding=utf-8
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Mllama model."""
+
+import math
+import os
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig
+from transformers.models.mllama.modeling_mllama import (
+ MllamaCrossAttentionDecoderLayer,
+ MllamaForCausalLM,
+ MllamaForConditionalGeneration,
+ MllamaSelfAttentionDecoderLayer,
+ MllamaTextCrossAttention,
+ MllamaTextModel,
+ MllamaTextSelfAttention,
+ MllamaVisionModel,
+ _prepare_4d_causal_attention_mask_with_cache_position,
+ _prepare_aspect_ratio_attention_mask,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+from transformers.utils import (
+ logging,
+)
+
+from ...modeling_attn_mask_utils import (
+ _gaudi_prepare_4d_causal_attention_mask,
+)
+
+
+logger = logging.get_logger(__name__)
+
+try:  # optional dependency: the HPU fused scaled dot-product attention kernel
+ from habana_frameworks.torch.hpex.kernels import FusedSDPA
+except ImportError:
+ print("Not using HPU fused scaled dot-product attention kernel.")
+ FusedSDPA = None  # sentinel: the attention classes in this file test `FusedSDPA` to choose between the fused and the matmul/softmax path
+
+
+class ModuleFusedSDPA(torch.nn.Module):  # nn.Module wrapper whose forward delegates to `.apply` of the kernel object supplied at construction
+ def __init__(self, fusedSDPA):  # fusedSDPA: any object exposing `.apply(...)`; this file passes the imported `FusedSDPA`
+ super().__init__()
+ self._hpu_kernel_fsdpa = fusedSDPA
+
+ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale):  # NOTE(review): `is_casual` looks like a typo for `is_causal`; kept because renaming would change the keyword interface
+ return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale)  # all arguments are forwarded positionally, unchanged and in this order
+
+
+def _prepare_cross_attention_mask(
+ cross_attention_mask: torch.Tensor,
+ num_vision_tokens: int,
+ dtype: str,  # NOTE(review): used as a torch.dtype below (`.to(dtype)`, `torch.finfo(dtype)`); the `str` annotation looks wrong
+ token_idx: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Build the additive 4D cross-attention mask and a per-text-row indicator of rows that are not fully masked.
+ Copied from _prepare_cross_attention_mask: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L99
+ The only differences are:
+ - if there is padding on the right of cross_attention_mask (rows from token_idx onward), do not mask it out, or else it will impact the softmax in cross-attention
+ Returns (cross_attention_mask, full_text_row_masked_out_mask); the returned row indicator is the one computed before the token_idx override."""
+ # reshape so it can be used by attn module
+ batch_size, text_total_length, *_ = cross_attention_mask.shape
+ cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)  # repeat every entry along dim 3 num_vision_tokens times
+ cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)  # flatten the trailing dims into one
+ cross_attention_mask = cross_attention_mask.unsqueeze(1)  # insert a size-1 dim at position 1
+
+ # invert the mask
+ inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
+ cross_attention_mask = inverted_cross_attn_mask.masked_fill(
+ inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min  # non-zero inverted entries become the dtype minimum; the rest stay 0
+ )
+
+ # apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] where the value is 0 if a full row in the cross attn mask's
+ # last dimension contains only the dtype's minimum ("negative infinity") value, otherwise it's 1
+ negative_inf_value = torch.finfo(dtype).min
+ full_text_row_masked_out_mask = (
+ (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
+ )
+ if token_idx is not None:
+ full_text_row_masked_out_mask2 = full_text_row_masked_out_mask.clone()  # work on a copy so the returned indicator is not modified
+ full_text_row_masked_out_mask2[:, :, token_idx:, :] = 1  # rows at/after token_idx keep their mask values in the multiply below
+ cross_attention_mask *= full_text_row_masked_out_mask2
+ else:
+ cross_attention_mask *= full_text_row_masked_out_mask  # rows flagged 0 are multiplied to all-zero
+
+ return cross_attention_mask, full_text_row_masked_out_mask
+
+
+class GaudiMllamaTextCrossAttention(MllamaTextCrossAttention):  # overrides `__init__` and `forward` of the transformers class
+ def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None  # None when the FusedSDPA import failed
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Cross-attend from `hidden_states` to `cross_attention_states` (or to cached key/values when no new states are given).
+ Copied from MllamaTextCrossAttention::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L512
+ The only differences are:
+ - add token_idx support
+ - add support for a past_key_value that is not a Cache instance
+ - tolerate cache_position being None
+ - add use_flash_attention and flash_attention_recompute
+ Returns (attn_output, attn_weights, past_key_value); attn_weights is None unless output_attentions is set on the non-fused path."""
+ """Input shape: Batch x Time x Channel"""
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.q_proj(hidden_states)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ query_states = self.q_norm(query_states)
+
+ if cross_attention_states is not None:
+ key_states = self.k_proj(cross_attention_states)
+ value_states = self.v_proj(cross_attention_states)
+ key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ if not (FusedSDPA and use_flash_attention):  # key/value heads are expanded with repeat_kv only on the non-fused path
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ key_states = self.k_norm(key_states)
+ if past_key_value is not None:
+ # if we have a new image + new tokens, we only computed key_states on that new image
+ # we still update the cross key states, past_image, new_image. And use it!
+ if isinstance(past_key_value, Cache):
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ else:
+ if token_idx is not None:
+ past_key_value[0].index_copy_(2, token_idx - 1, key_states)  # in-place write into the cached tensor at index token_idx - 1 along dim 2
+ past_key_value[1].index_copy_(2, token_idx - 1, value_states)
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ else:
+ key_states = torch.cat((past_key_value[0], key_states), dim=2)
+ value_states = torch.cat((past_key_value[1], value_states), dim=2)
+ if use_cache and not isinstance(past_key_value, Cache):
+ past_key_value = [key_states, value_states]  # non-Cache caching: hand the key/value pair back as a list
+ elif not isinstance(past_key_value, Cache) and past_key_value is not None:  # no new states: reuse the cached tuple/list key/values
+ key_states, value_states = (past_key_value[0], past_key_value[1])
+ elif cache_position is not None and cache_position[0] != 0:  # no new states: read this layer's entries from the Cache object
+ key_states, value_states = (
+ past_key_value.key_cache[self.layer_idx],
+ past_key_value.value_cache[self.layer_idx],
+ )
+ else:
+ raise ValueError(
+ "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
+ )
+
+ if FusedSDPA and use_flash_attention:
+ import habana_frameworks.torch.hpu as ht
+
+ if q_len == 1:
+ # next token
+ use_recompute = True if os.getenv("QUANT_CONFIG", "") else False  # single-token steps enable recompute only when QUANT_CONFIG is non-empty
+ with ht.sdp_kernel(enable_recompute=use_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None
+ )
+ else:
+ with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None
+ )
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions or (FusedSDPA and use_flash_attention):  # the fused path never assigns attn_weights; without this, output_attentions=True raised UnboundLocalError at the return
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class GaudiMllamaTextSelfAttention(MllamaTextSelfAttention):  # overrides `__init__` and `forward` of the transformers class
+ def __init__(self, config: Optional[MllamaTextConfig] = None, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+ self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None  # None when the FusedSDPA import failed
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_embeddings: torch.Tensor,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ past_key_value=None,
+ cache_position=None,
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ **kwargs,
+ ):
+ """Self-attention with rotary embeddings over `hidden_states`, supporting both `Cache` objects and tuple/list key/value caches.
+ Copied from MllamaTextSelfAttention::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L733
+ The only differences are:
+ - add token_idx support
+ - add support for a past_key_value that is not a Cache instance
+ - add use_flash_attention and flash_attention_recompute
+ Returns (attn_output, attn_weights, past_key_value); attn_weights is None unless output_attentions is set on the non-fused path. Extra **kwargs are accepted and ignored."""
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings  # position_embeddings is unpacked as a (cos, sin) pair
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ if isinstance(past_key_value, Cache):
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs
+ )
+ else:
+ if token_idx is not None:
+ past_key_value[0].index_copy_(2, token_idx - 1, key_states)  # in-place write into the cached tensor at index token_idx - 1 along dim 2
+ past_key_value[1].index_copy_(2, token_idx - 1, value_states)
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ else:
+ key_states = torch.cat((past_key_value[0], key_states), dim=2)
+ value_states = torch.cat((past_key_value[1], value_states), dim=2)
+ if use_cache and not isinstance(past_key_value, Cache):
+ past_key_value = [key_states, value_states]  # non-Cache caching: hand the key/value pair back as a list
+
+ if FusedSDPA and use_flash_attention:
+ import habana_frameworks.torch.hpu as ht
+
+ if q_len == 1:
+ # next token
+ use_recompute = True if os.getenv("QUANT_CONFIG", "") else False  # single-token steps enable recompute only when QUANT_CONFIG is non-empty
+ with ht.sdp_kernel(enable_recompute=use_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None
+ )
+ else:
+ with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+ attn_output = self.fused_scaled_dot_product_attention(
+ query_states, key_states, value_states, attention_mask, 0.0, False, None
+ )
+ else:
+ key_states = repeat_kv(key_states, self.num_key_value_groups)  # key/value heads are expanded only on the non-fused path
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions or (FusedSDPA and use_flash_attention):  # the fused path never assigns attn_weights; without this, output_attentions=True raised UnboundLocalError at the return
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer
+class GaudiMllamaSelfAttentionDecoderLayer(MllamaSelfAttentionDecoderLayer):  # decoder layer whose self-attention is replaced by the Gaudi implementation
+ def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None:
+ super(GaudiMllamaSelfAttentionDecoderLayer, self).__init__(config, layer_idx)
+ self.self_attn = GaudiMllamaTextSelfAttention(config, layer_idx=layer_idx)  # overwrites the attribute after the parent constructor has run
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cross_attention_states: Optional[torch.Tensor] = None,  # not read in this method's body
+ cross_attention_mask: Optional[torch.Tensor] = None,  # not read in this method's body
+ attention_mask: Optional[torch.Tensor] = None,
+ full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # not read in this method's body
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ token_idx: Optional[torch.Tensor] = None,
+ use_flash_attention: Optional[bool] = False,
+ flash_attention_recompute: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """Run layernorm -> self-attention -> residual add, then layernorm -> MLP -> residual add.
+ Copied from MllamaSelfAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L904
+ The only differences are:
+ - add token_idx input
+ - add use_flash_attention and flash_attention_recompute
+ Returns a tuple: (hidden_states,), then the attention weights if output_attentions, then the present key/value if use_cache."""
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,  # GaudiMllamaTextSelfAttention.forward has no such parameter; it lands in its **kwargs
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ token_idx=token_idx,
+ use_flash_attention=use_flash_attention,
+ flash_attention_recompute=flash_attention_recompute,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class GaudiMllamaCrossAttentionDecoderLayer(MllamaCrossAttentionDecoderLayer):
+    """Cross-attention decoder layer whose attention module is replaced by `GaudiMllamaTextCrossAttention`."""
+
+    def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None:
+        super().__init__(config, layer_idx)
+        # Swap in the Gaudi cross-attention implementation built from the same config.
+        self.cross_attn = GaudiMllamaTextCrossAttention(config, layer_idx=layer_idx)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cross_attention_states: torch.Tensor,
+        cross_attention_mask: torch.Tensor,
+        attention_mask: torch.Tensor,
+        full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[torch.Tensor] = None,
+        token_idx: Optional[torch.Tensor] = None,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        """
+        Apply tanh-gated cross-attention and a tanh-gated MLP, each added back onto its residual.
+
+        Copied from MllamaCrossAttentionDecoderLayer::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L989
+        The only differences are:
+        - add token_idx support
+        - pass use_cache to cross_attn
+        - add use_flash_attention and flash_attention_recompute
+
+        `attention_mask`, `position_ids` and `position_embeddings` are accepted for signature parity with the
+        self-attention layer and are not read here.
+
+        Returns:
+            A tuple starting with the new hidden states, followed by the attention weights when
+            `output_attentions` is set and by the key/value cache when `use_cache` is set.
+        """
+        normed_states = self.input_layernorm(hidden_states)
+        attn_output, attn_weights, past_key_value = self.cross_attn(
+            hidden_states=normed_states,
+            attention_mask=cross_attention_mask,
+            cross_attention_states=cross_attention_states,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+            use_cache=use_cache,
+            token_idx=token_idx,
+            use_flash_attention=use_flash_attention,
+            flash_attention_recompute=flash_attention_recompute,
+        )
+        hidden_states = hidden_states + self.cross_attn_attn_gate.tanh() * attn_output
+
+        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
+        if full_text_row_masked_out_mask is not None:
+            # Scale the MLP output by the row mask before it is gated and added to the residual.
+            mlp_output = full_text_row_masked_out_mask[:, 0] * mlp_output  # type: ignore
+        hidden_states = hidden_states + self.cross_attn_mlp_gate.tanh() * mlp_output
+
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs = outputs + (attn_weights,)
+        if use_cache:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+
+class GaudiMllamaTextModel(MllamaTextModel):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        cross_attention_states: Optional[torch.FloatTensor] = None,
+        cross_attention_mask: Optional[torch.Tensor] = None,
+        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        token_idx: Optional[torch.Tensor] = None,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """
+        Run the text decoder stack over `input_ids` or `inputs_embeds`.
+
+        Copied from MllamaTextModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1617
+        The only differences are:
+        - add token_idx support
+        - add support if past_key_value is not Cache
+        - add use_flash_attention and flash_attention_recompute
+
+        On the HPU path taken here, `cache_position` is discarded (reset to None) and the causal mask is
+        built with `_gaudi_prepare_4d_causal_attention_mask`.
+
+        Raises:
+            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
+
+        Returns:
+            A `BaseModelOutputWithPast`, or a tuple of its non-None fields when `return_dict` is false.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+            )
+            use_cache = False
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        hidden_states = inputs_embeds
+        uses_cache_object = isinstance(past_key_values, Cache)
+        if past_key_values is None:
+            past_seen_tokens = 0
+        elif uses_cache_object:
+            past_seen_tokens = past_key_values.get_seq_length()
+        else:
+            # Legacy tuple cache: read the length from axis 2 of the first layer's first cached tensor.
+            past_seen_tokens = past_key_values[0][0].shape[2]
+
+        ignore_cache_position = True  # Ignore cache position on HPU, otherwise HPU graphs may have issues
+        if ignore_cache_position is False:
+            if cache_position is None:
+                cache_position = torch.arange(
+                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                )
+            if position_ids is None:
+                position_ids = cache_position.unsqueeze(0)
+            causal_mask = self._update_causal_mask(
+                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+            )
+        else:
+            if position_ids is None:
+                position_ids = torch.arange(
+                    past_seen_tokens,
+                    inputs_embeds.shape[1] + past_seen_tokens,
+                    dtype=torch.long,
+                    device=inputs_embeds.device,
+                )
+                position_ids = position_ids.unsqueeze(0)
+            cache_position = None
+            # Take (batch, seq_len) from `inputs_embeds`: `input_ids` is None when the caller supplies
+            # embeddings directly, and `inputs_embeds` is always populated at this point.
+            causal_mask = _gaudi_prepare_4d_causal_attention_mask(
+                attention_mask,
+                inputs_embeds.shape[:2],
+                inputs_embeds,
+                past_seen_tokens,
+            )
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None if uses_cache_object else ()
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            # For text-only path we should skip cross attention layers.
+            # Let's check if the layer is cross attention layer and if we have cross attention states
+            # or cached cross attention states. A legacy tuple cache is never reported as empty.
+            is_cross_attention_layer = idx in self.cross_attention_layers
+            is_cross_attention_cache_empty = past_key_values is None or (
+                uses_cache_object and past_key_values.get_seq_length(idx) == 0
+            )
+
+            if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty:
+                continue
+
+            if self.gradient_checkpointing and self.training:
+                # Arguments are positional and must follow the decoder layers' `forward` parameter order.
+                # The last three carry the Gaudi-specific options so that checkpointed training honours
+                # `use_flash_attention` / `flash_attention_recompute` exactly like the regular path.
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    cross_attention_states,
+                    cross_attention_mask,
+                    causal_mask,
+                    full_text_row_masked_out_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    position_embeddings,
+                    token_idx,
+                    use_flash_attention,
+                    flash_attention_recompute,
+                )
+            else:
+                if uses_cache_object:
+                    past_key_value = past_key_values
+                else:
+                    past_key_value = None if past_key_values is None else past_key_values[idx]
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    cross_attention_states=cross_attention_states,
+                    cross_attention_mask=cross_attention_mask,
+                    attention_mask=causal_mask,
+                    full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                    token_idx=token_idx,
+                    use_flash_attention=use_flash_attention,
+                    flash_attention_recompute=flash_attention_recompute,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                layer_cache = layer_outputs[2 if output_attentions else 1]
+                if uses_cache_object:
+                    # A Cache object is shared by all layers; keep the latest reference.
+                    next_decoder_cache = layer_cache
+                else:
+                    # Legacy format: one entry per executed layer.
+                    next_decoder_cache += (layer_cache,)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    def _update_causal_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_tensor: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool,
+    ):
+        """
+        Build the causal attention mask for the `cache_position`-based (non-HPU) path.
+
+        Copied from MllamaTextModel::_update_causal_mask: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1768
+        The only differences are:
+        - add support if past_key_value is not Cache
+
+        Returns:
+            None when the attention implementation can work without an explicit mask, the unchanged
+            `attention_mask` for flash_attention_2 when it contains zeros, otherwise a 4D causal mask.
+        """
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        if past_key_values is None:
+            past_seen_tokens = 0
+        elif isinstance(past_key_values, Cache):
+            past_seen_tokens = past_key_values.get_seq_length()
+        else:
+            past_seen_tokens = past_key_values[0][0].shape[2]
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            min_dtype=min_dtype,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+        )
+
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+
+class GaudiMllamaForCausalLM(MllamaForCausalLM):
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        cross_attention_states: Optional[torch.LongTensor] = None,
+        cross_attention_mask: Optional[torch.LongTensor] = None,
+        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        num_logits_to_keep: int = 0,
+        token_idx: Optional[torch.Tensor] = None,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        """
+        Run the text model, project its hidden states to vocabulary logits and optionally compute the
+        shifted next-token cross-entropy loss.
+
+        Copied from MllamaForCausalLM::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1871
+        The only differences are:
+        - add token_idx input
+        - add logits handle if token_idx is not None
+        - add use_flash_attention and flash_attention_recompute
+
+        Returns:
+            A `CausalLMOutputWithPast`, or, when `return_dict` is false, a tuple of the logits followed by
+            the remaining decoder outputs and prefixed with the loss when `labels` is given.
+        """
+        if output_attentions is None:
+            output_attentions = self.config.output_attentions
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+        if return_dict is None:
+            return_dict = self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            cross_attention_states=cross_attention_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            cross_attention_mask=cross_attention_mask,
+            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            token_idx=token_idx,
+            use_flash_attention=use_flash_attention,
+            flash_attention_recompute=flash_attention_recompute,
+        )
+
+        hidden_states = outputs[0]
+
+        # Only trim to the trailing positions when token_idx is not in use; with token_idx the full
+        # sequence is projected regardless of num_logits_to_keep.
+        trim_to_last_positions = token_idx is None and num_logits_to_keep != 0
+        lm_head_input = hidden_states[:, -num_logits_to_keep:, :] if trim_to_last_positions else hidden_states
+        # Logits are upcast to float here, so the loss below already works on float logits.
+        logits = self.lm_head(lm_head_input).float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n, then flatten the tokens
+            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
+            # Move labels onto the logits' device to enable model parallelism
+            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
+            loss = CrossEntropyLoss()(shift_logits, shift_labels)
+
+        if return_dict:
+            return CausalLMOutputWithPast(
+                loss=loss,
+                logits=logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+
+        output = (logits,) + outputs[1:]
+        if loss is None:
+            return output
+        return (loss,) + output
+
+
+class GaudiMllamaForConditionalGeneration(MllamaForConditionalGeneration):
+    def __init__(self, config: MllamaConfig):
+        # sdpa is better for vision model in HPU; set before the parent builds its sub-modules
+        config._attn_implementation = "sdpa"
+        super().__init__(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        aspect_ratio_mask: Optional[torch.Tensor] = None,
+        aspect_ratio_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_states: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        num_logits_to_keep: int = 0,
+        token_idx: Optional[torch.Tensor] = None,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
+        **kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        """
+        Encode the image (when `pixel_values` is given), prepare the cross-attention masks and delegate to
+        the language model.
+
+        Copied from MllamaForConditionalGeneration::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2077
+        The only differences are:
+        - add token_idx input
+        - add use_flash_attention and flash_attention_recompute
+
+        Raises:
+            ValueError: on conflicting inputs (`input_ids` vs `inputs_embeds`, `pixel_values` vs
+                `inputs_embeds`, `pixel_values` vs `cross_attention_states`) or when `pixel_values` is
+                given without `aspect_ratio_ids`.
+        """
+        if output_attentions is None:
+            output_attentions = self.config.output_attentions
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+        if return_dict is None:
+            return_dict = self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        has_image = pixel_values is not None
+        if has_image and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if has_image and cross_attention_states is not None:
+            raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
+
+        if has_image:
+            if aspect_ratio_ids is None:
+                raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
+            # get vision tokens from vision model
+            vision_outputs = self.vision_model(
+                pixel_values=pixel_values,
+                aspect_ratio_ids=aspect_ratio_ids,
+                aspect_ratio_mask=aspect_ratio_mask,
+                output_hidden_states=output_hidden_states,
+                output_attentions=output_attentions,
+                return_dict=return_dict,
+            )
+            vision_features = vision_outputs[0]
+            # Project to the text hidden size and flatten the leading axes into one.
+            cross_attention_states = self.multi_modal_projector(vision_features).reshape(
+                -1, vision_features.shape[-2], self.hidden_size
+            )
+
+        full_text_row_masked_out_mask = None
+        if cross_attention_mask is not None:
+            cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
+                cross_attention_mask,
+                num_vision_tokens=self.vision_model.num_patches,
+                dtype=self.dtype,
+                token_idx=token_idx,
+            )
+
+        # Keep only the mask rows for the positions processed in this step.
+        if cross_attention_mask is not None:
+            if cache_position is not None:
+                cross_attention_mask = cross_attention_mask[:, :, cache_position]
+                full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
+            elif past_key_values is not None:
+                if token_idx is None:
+                    cross_attention_mask = cross_attention_mask[:, :, -1:]
+                    full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, -1:]
+                else:
+                    current_row = token_idx - 1
+                    cross_attention_mask = torch.index_select(cross_attention_mask, -2, current_row)
+                    full_text_row_masked_out_mask = torch.index_select(
+                        full_text_row_masked_out_mask, -2, current_row
+                    )
+
+        return self.language_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            cross_attention_states=cross_attention_states,
+            cross_attention_mask=cross_attention_mask,
+            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
+            token_idx=token_idx,
+            use_flash_attention=use_flash_attention,
+            flash_attention_recompute=flash_attention_recompute,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids=None,
+        inputs_embeds=None,
+        attention_mask=None,
+        position_ids=None,
+        pixel_values=None,
+        aspect_ratio_ids=None,
+        aspect_ratio_mask=None,
+        cross_attention_mask=None,
+        past_key_values=None,
+        use_cache=False,
+        cache_position=None,
+        num_logits_to_keep=None,
+        **kwargs,
+    ):
+        """
+        Assemble the keyword arguments for one generation step.
+
+        Copied from MllamaForConditionalGeneration::prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2208
+        The only differences are:
+        - add token_idx handling
+        - add bucket_internal handling
+        - add use_flash_attention and flash_attention_recompute
+
+        `token_idx`, `bucket_internal`, `use_flash_attention` and `flash_attention_recompute` are read from
+        `kwargs`. The returned `cache_position` is always None.
+        """
+        token_idx = kwargs.get("token_idx")
+        bucket_internal = kwargs.get("bucket_internal")
+
+        if past_key_values is not None:
+            # Decoding step: keep only the token(s) that still have to be processed.
+            if token_idx is not None:
+                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
+            elif inputs_embeds is not None:  # Exception 1
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+        elif bucket_internal and token_idx is not None:
+            # for the 1st token we can slice the inputs till token idx for the fwd pass.
+            input_ids = input_ids[:, :token_idx]
+            attention_mask = attention_mask[:, :token_idx]
+            if cross_attention_mask is not None:
+                cross_attention_mask = cross_attention_mask[:, :token_idx, ...]
+
+        # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                if token_idx is None:
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
+                else:
+                    position_ids = torch.index_select(position_ids, 1, token_idx - 1)
+
+                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+                position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+        else:
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
+        # keep cache_position implementation as None for HPU
+        cache_position = None
+
+        model_inputs["position_ids"] = position_ids
+        model_inputs["cache_position"] = cache_position
+        model_inputs["past_key_values"] = past_key_values
+        model_inputs["use_cache"] = use_cache
+        model_inputs["attention_mask"] = attention_mask
+        model_inputs["cross_attention_mask"] = cross_attention_mask
+        model_inputs["token_idx"] = token_idx
+        model_inputs["use_flash_attention"] = kwargs.get("use_flash_attention")
+        model_inputs["flash_attention_recompute"] = kwargs.get("flash_attention_recompute")
+
+        # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
+        # to compute image hidden states, otherwise they are cached within each cross attn layer
+        if (input_ids == self.config.image_token_index).any():
+            model_inputs["pixel_values"] = pixel_values
+            model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
+            model_inputs["aspect_ratio_mask"] = aspect_ratio_mask
+
+        return model_inputs
+
+    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
+        """
+        Update `model_kwargs` after a generation step and extend the cross-attention mask to the new token.
+
+        Copied from MllamaForConditionalGeneration::_update_model_kwargs_for_generation: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L2274
+        The only differences are:
+        - add token_idx handling
+        """
+        cross_attention_mask_prev = model_kwargs.get("cross_attention_mask")
+        # Deliberately start the lookup *after* MllamaForConditionalGeneration so its own override is
+        # skipped; the mask extension it would perform is done below instead.
+        model_kwargs = super(MllamaForConditionalGeneration, self)._update_model_kwargs_for_generation(
+            outputs=outputs,
+            model_kwargs=model_kwargs,
+            is_encoder_decoder=is_encoder_decoder,
+            **kwargs,
+        )
+
+        if cross_attention_mask_prev is None:
+            return model_kwargs
+
+        # add cross-attn mask for new token
+        token_idx = model_kwargs.get("token_idx")
+        if token_idx is None:
+            # Append a copy of the last row.
+            model_kwargs["cross_attention_mask"] = torch.cat(
+                [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
+            )
+            return model_kwargs
+
+        # Copy the row at token_idx - 2 into row token_idx - 1, in place.
+        mask = cross_attention_mask_prev[:, token_idx - 2 : token_idx - 1, ...]
+        cross_attention_mask_prev.index_copy_(1, token_idx - 1, mask)
+        model_kwargs["cross_attention_mask"] = cross_attention_mask_prev
+        return model_kwargs
+
+
+class GaudiMllamaVisionModel(MllamaVisionModel):
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        aspect_ratio_ids: torch.Tensor,
+        aspect_ratio_mask: torch.Tensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
+        """
+        Encode tiled images with the local and global transformers and concatenate the final hidden state
+        with the selected intermediate layer outputs along the last axis.
+
+        Copied from MllamaVisionModel::forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/mllama/modeling_mllama.py#L1425
+        The only differences are:
+        - optimize perf of stage "Collect intermediate layer outputs from encoder output"
+
+        `pixel_values` is unpacked as (batch, concurrent media, tiles, channels, height, width).
+        """
+        if output_attentions is None:
+            output_attentions = self.config.output_attentions
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+        if return_dict is None:
+            return_dict = self.config.use_return_dict
+
+        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape
+        num_media = batch_size * num_concurrent_media
+
+        pixel_values = pixel_values.reshape(num_media * num_tiles, num_channels, height, width)
+        aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)
+
+        # Patch embedding
+        patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device))
+        hidden_state = patch_embeds.flatten(2).transpose(1, 2)
+
+        # Tile embeddings
+        _, num_patches, dim = hidden_state.shape
+        hidden_state = hidden_state.reshape(num_media, num_tiles, -1, dim)
+        hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids)
+
+        # Add cls token
+        hidden_state = hidden_state.reshape(num_media * num_tiles, num_patches, dim)
+        hidden_state = self.apply_class_embedding(hidden_state)
+        num_patches += 1
+
+        # Position embeddings
+        hidden_state = hidden_state.reshape(num_media, num_tiles, num_patches, dim)
+        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
+
+        hidden_state = self.layernorm_pre(hidden_state)
+
+        # Pad the patch axis (dim -2) on the right up to the next multiple of 8
+        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
+        # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
+        padding = (0, 0, 0, num_padding_patches)
+        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
+        # `None` keeps everything when nothing was padded (a `-0` slice bound would drop all patches)
+        slice_index = -num_padding_patches if num_padding_patches > 0 else None
+        padded_patches = num_patches + num_padding_patches
+
+        # Prepare attention mask
+        attention_mask = aspect_ratio_mask.reshape(num_media, -1)
+        attention_mask = _prepare_aspect_ratio_attention_mask(
+            aspect_ratio_mask=attention_mask,
+            num_patches=self.num_patches,
+            target_length=hidden_state.shape[2],
+            dtype=self.dtype,
+        )
+
+        # Apply encoder; hidden states are always requested because the intermediate layers are needed below
+        hidden_state = hidden_state.view(num_media, -1, dim)
+        output = self.transformer(
+            hidden_state,
+            attention_mask=attention_mask,
+            output_hidden_states=True,
+            output_attentions=output_attentions,
+        )
+        hidden_state = self.layernorm_post(output[0])
+
+        # Apply global encoder
+        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
+        hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids)
+        hidden_state = hidden_state.reshape(num_media, num_tiles * padded_patches, dim)
+        global_output = self.global_transformer(
+            hidden_state,
+            attention_mask=attention_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        hidden_state = global_output[0]
+
+        # Remove padding from hidden state
+        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
+        hidden_state = hidden_state[:, :, :slice_index]
+        hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim)
+
+        # Collect intermediate layer outputs from encoder output.
+        # Only the wanted layers are stacked; the replaced implementation stacked every layer and then
+        # indexed the last axis with `self.intermediate_layers_indices`.
+        all_intermediate_hidden_states = output[1]
+        selected_layers = [
+            layer_output
+            for layer_idx, layer_output in enumerate(all_intermediate_hidden_states)
+            if layer_idx in self.intermediate_layers_indices
+        ]
+        intermediate_hidden_states = torch.stack(selected_layers, dim=-1)
+
+        # Remove padding from intermediate hidden states
+        intermediate_hidden_states = intermediate_hidden_states.reshape(num_media, num_tiles, padded_patches, -1)
+        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
+        intermediate_hidden_states = intermediate_hidden_states.reshape(
+            batch_size, num_concurrent_media, num_tiles, num_patches, -1
+        )
+
+        # Concatenate final hidden state and intermediate hidden states
+        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
+
+        hidden_states = None
+        if output_hidden_states:
+            hidden_states = tuple(all_intermediate_hidden_states) + tuple(global_output[1])
+
+        attentions = None
+        if output_attentions:
+            # global transformer in contrast to `self.transformer` doesn't always return hidden states so we might go index out-of-range
+            global_attn_index = 2 if output_hidden_states else 1
+            attentions = tuple(output[2]) + tuple(global_output[global_attn_index])
+
+        if not return_dict:
+            return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None)
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_state,
+            hidden_states=hidden_states,
+            attentions=attentions,
+        )
diff --git a/tests/baselines/Llama_3_2_11B_Vision_Instruct.json b/tests/baselines/Llama_3_2_11B_Vision_Instruct.json
new file mode 100644
index 0000000000..3789c63fa9
--- /dev/null
+++ b/tests/baselines/Llama_3_2_11B_Vision_Instruct.json
@@ -0,0 +1,38 @@
+{
+ "gaudi2": {
+ "image2text_lora_finetune": {
+ "num_train_epochs": 2,
+ "eval_batch_size": 4,
+ "distribution": {
+ "multi_card": {
+ "learning_rate": 5e-5,
+ "train_batch_size": 2,
+ "train_runtime": 470,
+ "train_samples_per_second": 22,
+ "eval_accuracy": 0.6,
+ "extra_arguments": [
+ "--bf16",
+ "--gradient_accumulation_steps 8",
+ "--eval_strategy no",
+ "--save_strategy no",
+ "--warmup_steps 50",
+ "--lr_scheduler_type constant",
+ "--max_grad_norm 0.3",
+ "--logging_steps 1",
+ "--use_hpu_graphs_for_inference",
+ "--lora_rank 8",
+ "--lora_alpha 8",
+ "--lora_dropout 0.1",
+ "--lora_target_modules '.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'",
+ "--low_cpu_mem_usage True",
+ "--adam_epsilon 1e-08",
+ "--input_column_name image query",
+ "--output_column_name answers",
+ "--remove_unused_columns False",
+ "--max_seq_length 512"
+ ]
+ }
+ }
+ }
+ }
+}
diff --git a/tests/test_examples.py b/tests/test_examples.py
index f24d250880..f84cdc75c6 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -34,6 +34,7 @@
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+ MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_MAPPING,
)
from transformers.testing_utils import slow
@@ -203,8 +204,8 @@ def is_valid_model_type(model_type: str) -> bool:
),
"run_image2text_lora_finetune": _get_supported_models_for_script(
MODELS_TO_TEST_MAPPING,
- MODEL_MAPPING,
- ["idefics2"],
+ MODEL_FOR_VISION_2_SEQ_MAPPING,
+ ["idefics2", "mllama"],
),
}
@@ -421,10 +422,9 @@ def test(self):
create_clip_roberta_model()
self._install_requirements(example_script.parent / "requirements.txt")
-
- path_to_baseline = BASELINE_DIRECTORY / Path(model_name.split("/")[-1].replace("-", "_")).with_suffix(
- ".json"
- )
+ path_to_baseline = BASELINE_DIRECTORY / Path(
+ model_name.split("/")[-1].replace("-", "_").replace(".", "_")
+ ).with_suffix(".json")
with path_to_baseline.open("r") as json_file:
device = "gaudi2" if IS_GAUDI2 else "gaudi"
baseline = json.load(json_file)[device]
diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py
index 1cb8b95b33..60049bf46e 100644
--- a/tests/test_image_to_text_example.py
+++ b/tests/test_image_to_text_example.py
@@ -20,6 +20,7 @@
("llava-hf/llava-v1.6-vicuna-7b-hf", 1, 35.00608681379742),
("llava-hf/llava-v1.6-vicuna-13b-hf", 1, 23.527610042925),
("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077),
+ ("meta-llama/Llama-3.2-11B-Vision-Instruct", 1, 20.407843538649303),
],
"fp8": [
("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062),
diff --git a/tests/utils.py b/tests/utils.py
index 18c00a564c..7eab1b06be 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -62,6 +62,7 @@
"protst": [("mila-intel/protst-esm1b-for-sequential-classification", "Habana/gpt2")],
"qwen2": [("Qwen/Qwen2-7B", "Habana/qwen"), ("Qwen/Qwen2-72B", "Habana/qwen")],
"idefics2": [("HuggingFaceM4/idefics2-8b", "Habana/gpt2")],
+ "mllama": [("meta-llama/Llama-3.2-11B-Vision-Instruct", "Habana/gpt2")],
}
MODELS_TO_TEST_FOR_QUESTION_ANSWERING = [